Skip to content

Commit 11980ac

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Add convenience method to get viewable embeddings.
1 parent fdde90b commit 11980ac

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

src/embeddings.rs

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,13 @@ impl PyEmbeddings {
7575
limit: usize,
7676
mask: (bool, bool, bool),
7777
) -> PyResult<Vec<PyObject>> {
78-
use EmbeddingsWrap::*;
7978
let embeddings = self.embeddings.borrow();
80-
let embeddings = match &*embeddings {
81-
View(e) => e,
82-
NonView(_) => {
83-
return Err(exceptions::ValueError::py_err(
84-
"Analogy queries are not supported for this type of embedding matrix",
85-
));
86-
}
87-
};
79+
80+
let embeddings = embeddings.view().ok_or_else(|| {
81+
exceptions::ValueError::py_err(
82+
"Analogy queries are not supported for this type of embedding matrix",
83+
)
84+
})?;
8885

8986
let results = embeddings
9087
.analogy_masked([word1, word2, word3], [mask.0, mask.1, mask.2], limit)
@@ -216,20 +213,15 @@ impl PyEmbeddings {
216213
fn similarity(&self, py: Python, word: &str, limit: usize) -> PyResult<Vec<PyObject>> {
217214
let embeddings = self.embeddings.borrow();
218215

219-
use EmbeddingsWrap::*;
220-
let embeddings = match &*embeddings {
221-
View(e) => e,
222-
NonView(_) => {
223-
return Err(exceptions::ValueError::py_err(
224-
"Similarity queries are not supported for this type of embedding matrix",
225-
));
226-
}
227-
};
216+
let embeddings = embeddings.view().ok_or_else(|| {
217+
exceptions::ValueError::py_err(
218+
"Similarity queries are not supported for this type of embedding matrix",
219+
)
220+
})?;
228221

229-
let results = match embeddings.word_similarity(word, limit) {
230-
Some(results) => results,
231-
None => return Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
232-
};
222+
let results = embeddings
223+
.word_similarity(word, limit)
224+
.ok_or_else(|| exceptions::KeyError::py_err("Unknown word and n-grams"))?;
233225

234226
let mut r = Vec::with_capacity(results.len());
235227
for ws in results {

src/embeddings_wrap.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,11 @@ impl EmbeddingsWrap {
3939
NonView(e) => e.embedding(word),
4040
}
4141
}
42+
43+
pub fn view(&self) -> Option<&Embeddings<VocabWrap, StorageViewWrap>> {
44+
match self {
45+
EmbeddingsWrap::NonView(_) => None,
46+
EmbeddingsWrap::View(storage) => Some(storage),
47+
}
48+
}
4249
}

0 commit comments

Comments
 (0)