Skip to content

Commit 4293fce

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Simplify API with convenience methods.
1 parent 4fefb6a commit 4293fce

File tree

1 file changed

+27
-32
lines changed

1 file changed

+27
-32
lines changed

src/lib.rs

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,24 @@ enum EmbeddingsWrap {
6161
View(Embeddings<VocabWrap, StorageViewWrap>),
6262
}
6363

64+
impl EmbeddingsWrap {
65+
fn storage(&self) -> &Storage {
66+
use EmbeddingsWrap::*;
67+
match self {
68+
NonView(e) => e.storage(),
69+
View(e) => e.storage(),
70+
}
71+
}
72+
73+
fn vocab(&self) -> &VocabWrap {
74+
use EmbeddingsWrap::*;
75+
match self {
76+
NonView(e) => e.vocab(),
77+
View(e) => e.vocab(),
78+
}
79+
}
80+
}
81+
6482
/// finalfusion embeddings.
6583
#[pyclass(name=Embeddings)]
6684
struct PyEmbeddings {
@@ -329,12 +347,9 @@ impl PyVocab {
329347
fn item_to_indices(&self, key: String) -> PyResult<PyObject> {
330348
let embeds = self.embeddings.borrow();
331349

332-
use EmbeddingsWrap::*;
333-
let indices = match &*embeds {
334-
View(e) => e.vocab().idx(key.as_str()),
335-
NonView(e) => e.vocab().idx(key.as_str()),
336-
};
337-
indices
350+
embeds
351+
.vocab()
352+
.idx(key.as_str())
338353
.map(|idx| {
339354
let gil = pyo3::Python::acquire_gil();
340355
match idx {
@@ -350,28 +365,17 @@ impl PyVocab {
350365
impl PySequenceProtocol for PyVocab {
351366
fn __len__(&self) -> PyResult<usize> {
352367
let embeds = self.embeddings.borrow();
353-
354-
use EmbeddingsWrap::*;
355-
match &*embeds {
356-
View(e) => Ok(e.vocab().words().len()),
357-
NonView(e) => Ok(e.vocab().words().len()),
358-
}
368+
Ok(embeds.vocab().len())
359369
}
360370

361371
fn __getitem__(&self, idx: isize) -> PyResult<String> {
362372
let embeds = self.embeddings.borrow();
373+
let words = embeds.vocab().words();
363374

364-
use EmbeddingsWrap::*;
365-
366-
let v = match &*embeds {
367-
View(e) => e.vocab(),
368-
NonView(e) => e.vocab(),
369-
};
370-
371-
if idx >= v.words().len() as isize || idx < 0 {
375+
if idx >= words.len() as isize || idx < 0 {
372376
Err(exceptions::IndexError::py_err("list index out of range"))
373377
} else {
374-
Ok(v.words()[idx as usize].clone())
378+
Ok(words[idx as usize].clone())
375379
}
376380
}
377381
}
@@ -408,20 +412,11 @@ impl PyIterProtocol for PyEmbeddingIterator {
408412
let slf = &mut *slf;
409413

410414
let embeddings = slf.embeddings.borrow();
411-
412-
use EmbeddingsWrap::*;
413-
let vocab = match &*embeddings {
414-
View(e) => e.vocab(),
415-
NonView(e) => e.vocab(),
416-
};
415+
let vocab = embeddings.vocab();
417416

418417
if slf.idx < vocab.len() {
419418
let word = vocab.words()[slf.idx].to_string();
420-
421-
let embed = match &*embeddings {
422-
View(e) => e.storage().embedding(slf.idx),
423-
NonView(e) => e.storage().embedding(slf.idx),
424-
};
419+
let embed = embeddings.storage().embedding(slf.idx);
425420

426421
slf.idx += 1;
427422

0 commit comments

Comments
 (0)