Skip to content

Commit 680505f

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Fix FromPyObject for SkipSet.
Fixes #51 by allowing all string-iterables in FromPyObject implementation for SkipSet.
1 parent 7b5b7ae commit 680505f

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

src/embeddings.rs

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ use ndarray::Array2;
1616
use numpy::{IntoPyArray, NpyDataType, PyArray1, PyArray2, ToPyArray};
1717
use pyo3::class::iter::PyIterProtocol;
1818
use pyo3::prelude::*;
19-
use pyo3::types::{PyAny, PyList, PySet, PyTuple};
20-
use pyo3::{exceptions, PyMappingProtocol, PyTypeInfo};
19+
use pyo3::types::{PyAny, PyTuple};
20+
use pyo3::{exceptions, PyMappingProtocol};
2121
use toml::{self, Value};
2222

2323
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
@@ -461,24 +461,15 @@ struct Skips<'a>(HashSet<&'a str>);
461461

462462
impl<'a> FromPyObject<'a> for Skips<'a> {
463463
fn extract(ob: &'a PyAny) -> Result<Self, PyErr> {
464-
let mut set = ob
465-
.len()
466-
.map(|len| HashSet::with_capacity(len))
467-
.unwrap_or_default();
468-
469-
let iter = if <PySet as PyTypeInfo>::is_instance(ob) {
470-
ob.iter().unwrap()
471-
} else if <PyList as PyTypeInfo>::is_instance(ob) {
472-
ob.iter().unwrap()
473-
} else {
474-
return Err(exceptions::TypeError::py_err("Iterable expected"));
475-
};
476-
477-
for el in iter {
478-
set.insert(
479-
el?.extract()
480-
.map_err(|_| exceptions::TypeError::py_err("Expected String"))?,
481-
);
464+
let mut set = ob.len().map(HashSet::with_capacity).unwrap_or_default();
465+
for el in ob
466+
.iter()
467+
.map_err(|_| exceptions::TypeError::py_err("Iterable expected"))?
468+
{
469+
let el = el?;
470+
set.insert(el.extract().map_err(|_| {
471+
exceptions::TypeError::py_err(format!("Expected String not: {}", el))
472+
})?);
482473
}
483474
Ok(Skips(set))
484475
}

0 commit comments

Comments
 (0)