Skip to content

Commit 8f34aa3

Browse files
committed
Properly implement PyProtocols for PyVocab.
In order to get arbitrary keys, PyMappingProtocol::__getitem__ needs to be implemented. To get O(1) __contains__, PySequenceProtocol::__contains__ needs to be implemented. To get proper Iteration support, PyIterProtocol::__iter__ needs to be implemented. PyO3/pyo3#611 This commit adds the correct implementation of the three traits to PyVocab.
1 parent e16c867 commit 8f34aa3

File tree

3 files changed

+153
-23
lines changed

3 files changed

+153
-23
lines changed

src/iter.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,37 @@ impl PyEmbedding {
8080
self.norm
8181
}
8282
}
83+
84+
#[pyclass(name=VocabIterator)]
85+
pub struct PyVocabIterator {
86+
embeddings: Rc<RefCell<EmbeddingsWrap>>,
87+
idx: usize,
88+
}
89+
90+
impl PyVocabIterator {
91+
pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>, idx: usize) -> Self {
92+
PyVocabIterator { embeddings, idx }
93+
}
94+
}
95+
96+
#[pyproto]
97+
impl PyIterProtocol for PyVocabIterator {
98+
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<PyVocabIterator>> {
99+
Ok(slf.into())
100+
}
101+
102+
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<String>> {
103+
let slf = &mut *slf;
104+
105+
let embeddings = slf.embeddings.borrow();
106+
let vocab = embeddings.vocab();
107+
108+
if slf.idx < vocab.words_len() {
109+
let word = vocab.words()[slf.idx].to_string();
110+
slf.idx += 1;
111+
Ok(Some(word))
112+
} else {
113+
Ok(None)
114+
}
115+
}
116+
}

src/vocab.rs

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ use std::rc::Rc;
44
use finalfusion::chunks::vocab::{NGramIndices, SubwordIndices, VocabWrap, WordIndex};
55
use finalfusion::prelude::*;
66
use pyo3::class::sequence::PySequenceProtocol;
7-
use pyo3::exceptions;
7+
use pyo3::exceptions::{IndexError, KeyError, ValueError};
88
use pyo3::prelude::*;
9+
use pyo3::{PyIterProtocol, PyMappingProtocol};
910

11+
use crate::iter::PyVocabIterator;
1012
use crate::EmbeddingsWrap;
1113

1214
type NGramIndex = (String, Option<usize>);
@@ -25,16 +27,18 @@ impl PyVocab {
2527

2628
#[pymethods]
2729
impl PyVocab {
28-
fn item_to_indices(&self, key: String) -> Option<PyObject> {
30+
#[args(default = "Python::acquire_gil().python().None()")]
31+
fn get(&self, key: &str, default: PyObject) -> Option<PyObject> {
2932
let embeds = self.embeddings.borrow();
30-
31-
embeds.vocab().idx(key.as_str()).map(|idx| {
32-
let gil = pyo3::Python::acquire_gil();
33-
match idx {
34-
WordIndex::Word(idx) => [idx].to_object(gil.python()),
35-
WordIndex::Subword(indices) => indices.to_object(gil.python()),
36-
}
37-
})
33+
let gil = pyo3::Python::acquire_gil();
34+
let idx = embeds.vocab().idx(key).map(|idx| match idx {
35+
WordIndex::Word(idx) => idx.to_object(gil.python()),
36+
WordIndex::Subword(indices) => indices.to_object(gil.python()),
37+
});
38+
if !default.is_none() && idx.is_none() {
39+
return Some(default);
40+
}
41+
idx
3842
}
3943

4044
fn ngram_indices(&self, word: &str) -> PyResult<Option<Vec<NGramIndex>>> {
@@ -44,7 +48,7 @@ impl PyVocab {
4448
VocabWrap::FinalfusionSubwordVocab(inner) => inner.ngram_indices(word),
4549
VocabWrap::FinalfusionNGramVocab(inner) => inner.ngram_indices(word),
4650
VocabWrap::SimpleVocab(_) => {
47-
return Err(exceptions::ValueError::py_err(
51+
return Err(ValueError::py_err(
4852
"querying n-gram indices is not supported for this vocabulary",
4953
))
5054
}
@@ -57,29 +61,71 @@ impl PyVocab {
5761
VocabWrap::FastTextSubwordVocab(inner) => Ok(inner.subword_indices(word)),
5862
VocabWrap::FinalfusionSubwordVocab(inner) => Ok(inner.subword_indices(word)),
5963
VocabWrap::FinalfusionNGramVocab(inner) => Ok(inner.subword_indices(word)),
60-
VocabWrap::SimpleVocab(_) => Err(exceptions::ValueError::py_err(
64+
VocabWrap::SimpleVocab(_) => Err(ValueError::py_err(
6165
"querying subwords' indices is not supported for this vocabulary",
6266
)),
6367
}
6468
}
6569
}
6670

67-
#[pyproto]
68-
impl PySequenceProtocol for PyVocab {
69-
fn __len__(&self) -> PyResult<usize> {
71+
impl PyVocab {
72+
fn str_to_indices(&self, query: &str) -> PyResult<WordIndex> {
7073
let embeds = self.embeddings.borrow();
71-
Ok(embeds.vocab().words_len())
74+
embeds
75+
.vocab()
76+
.idx(query)
77+
.ok_or_else(|| KeyError::py_err(format!("key not found: '{}'", query)))
7278
}
7379

74-
fn __getitem__(&self, idx: isize) -> PyResult<String> {
80+
fn validate_and_convert_isize_idx(&self, mut idx: isize) -> PyResult<usize> {
7581
let embeds = self.embeddings.borrow();
76-
let words = embeds.vocab().words();
82+
let vocab = embeds.vocab();
83+
if idx < 0 {
84+
idx += vocab.words_len() as isize;
85+
}
7786

78-
if idx >= words.len() as isize || idx < 0 {
79-
Err(exceptions::IndexError::py_err("list index out of range"))
87+
if idx >= vocab.words_len() as isize || idx < 0 {
88+
Err(IndexError::py_err("list index out of range"))
8089
} else {
81-
Ok(words[idx as usize].clone())
90+
Ok(idx as usize)
91+
}
92+
}
93+
}
94+
95+
#[pyproto]
96+
impl PyMappingProtocol for PyVocab {
97+
fn __getitem__(&self, query: PyObject) -> PyResult<PyObject> {
98+
let embeds = self.embeddings.borrow();
99+
let vocab = embeds.vocab();
100+
let gil = Python::acquire_gil();
101+
if let Ok(idx) = query.extract::<isize>(gil.python()) {
102+
let idx = self.validate_and_convert_isize_idx(idx)?;
103+
return Ok(vocab.words()[idx].clone().into_py(gil.python()));
104+
}
105+
106+
if let Ok(query) = query.extract::<String>(gil.python()) {
107+
return self.str_to_indices(&query).map(|idx| match idx {
108+
WordIndex::Subword(indices) => indices.into_py(gil.python()),
109+
WordIndex::Word(idx) => idx.into_py(gil.python()),
110+
});
82111
}
112+
113+
Err(KeyError::py_err("key must be integer or string"))
114+
}
115+
}
116+
117+
#[pyproto]
118+
impl PyIterProtocol for PyVocab {
119+
fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyVocabIterator> {
120+
Ok(PyVocabIterator::new(slf.embeddings.clone(), 0))
121+
}
122+
}
123+
124+
#[pyproto]
125+
impl PySequenceProtocol for PyVocab {
126+
fn __len__(&self) -> PyResult<usize> {
127+
let embeds = self.embeddings.borrow();
128+
Ok(embeds.vocab().words_len())
83129
}
84130

85131
fn __contains__(&self, word: String) -> PyResult<bool> {

tests/test_vocab.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
TEST_NGRAM_INDICES = [
24
('tüb',
35
14),
@@ -53,9 +55,19 @@
5355
1007)]
5456

5557

56-
def test_embeddings_with_norms_oov(embeddings_fifu):
58+
def test_get(embeddings_text_dims):
59+
vocab = embeddings_text_dims.vocab()
60+
assert vocab.get("one") is 0
61+
62+
63+
def test_get_oov(embeddings_fifu):
64+
vocab = embeddings_fifu.vocab()
65+
assert vocab.get("Something out of vocabulary") is None
66+
67+
68+
def test_get_oov_with_default(embeddings_fifu):
5769
vocab = embeddings_fifu.vocab()
58-
assert vocab.item_to_indices("Something out of vocabulary") is None
70+
assert vocab.get("Something out of vocabulary", default=-1) == -1
5971

6072

6173
def test_ngram_indices(subword_fifu):
@@ -72,3 +84,41 @@ def test_subword_indices(subword_fifu):
7284
for subword_index, test_ngram_index in zip(
7385
subword_indices, TEST_NGRAM_INDICES):
7486
assert subword_index == test_ngram_index[1]
87+
88+
89+
def test_int_idx(embeddings_text_dims):
90+
vocab = embeddings_text_dims.vocab()
91+
assert vocab[0] == "one"
92+
93+
94+
def test_int_idx_out_of_range(embeddings_text_dims):
95+
vocab = embeddings_text_dims.vocab()
96+
with pytest.raises(IndexError):
97+
_ = vocab[42]
98+
99+
100+
def test_negative_int_idx(embeddings_text_dims):
101+
vocab = embeddings_text_dims.vocab()
102+
assert vocab[-1] == "seven"
103+
104+
105+
def test_negative_int_idx_out_of_range(embeddings_text_dims):
106+
vocab = embeddings_text_dims.vocab()
107+
with pytest.raises(IndexError):
108+
_ = vocab[-42]
109+
110+
111+
def test_string_idx(embeddings_text_dims):
112+
vocab = embeddings_text_dims.vocab()
113+
assert vocab["one"] == 0
114+
115+
116+
def test_string_oov(embeddings_text_dims):
117+
vocab = embeddings_text_dims.vocab()
118+
with pytest.raises(KeyError):
119+
vocab["definitely in vocab"]
120+
121+
122+
def test_string_oov_subwords(subword_fifu):
123+
vocab = subword_fifu.vocab()
124+
assert sorted(vocab["tübingen"]) == [x[1] for x in TEST_NGRAM_INDICES]

0 commit comments

Comments
 (0)