Skip to content

Commit 1231184

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Fix doubly-wrapped option argument for embedding similarity.
1 parent 4f6db3c commit 1231184

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/embeddings.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,12 @@ impl PyEmbeddings {
310310
}
311311

312312
/// Perform a similarity query based on a query embedding.
313-
#[args(limit = 10, skip = "None")]
313+
#[args(limit = 10, skip = "Skips(HashSet::new())")]
314314
fn embedding_similarity(
315315
&self,
316316
py: Python,
317317
embedding: PyEmbedding,
318-
skip: Option<Option<Skips>>,
318+
skip: Skips,
319319
limit: usize,
320320
) -> PyResult<Vec<PyObject>> {
321321
let embeddings = self.embeddings.borrow();
@@ -336,11 +336,7 @@ impl PyEmbeddings {
336336
)));
337337
}
338338

339-
let results = if let Some(Some(skip)) = skip {
340-
embeddings.embedding_similarity_masked(embedding, limit, &skip.0)
341-
} else {
342-
embeddings.embedding_similarity(embedding, limit)
343-
};
339+
let results = embeddings.embedding_similarity_masked(embedding, limit, &skip.0);
344340

345341
Self::similarity_results(
346342
py,
@@ -462,6 +458,9 @@ struct Skips<'a>(HashSet<&'a str>);
462458
impl<'a> FromPyObject<'a> for Skips<'a> {
463459
fn extract(ob: &'a PyAny) -> Result<Self, PyErr> {
464460
let mut set = ob.len().map(HashSet::with_capacity).unwrap_or_default();
461+
if ob.is_none() {
462+
return Ok(Skips(set));
463+
}
465464
for el in ob
466465
.iter()
467466
.map_err(|_| exceptions::TypeError::py_err("Iterable expected"))?

0 commit comments

Comments
 (0)