Skip to content

Commit b094386

Browse files
danieldkDaniël de Kok
authored andcommitted
Implement model iteration.
The iterator is implemented in Python by iterating over vocab indices and extracting the corresponding tokens/embeddings.
1 parent 8abc5a0 commit b094386

File tree

2 files changed

+58
-11
lines changed

2 files changed

+58
-11
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ name = "finalfrontier"
88
crate-type = ["cdylib"]
99

1010
[dependencies.pyo3]
11-
version = "0.4"
11+
version = "0.5.0-alpha.1"
1212
features = ["extension-module"]
1313

1414
[dependencies]
1515
failure = "0.1"
1616
finalfrontier = "0.2"
17+
ndarray = "0.11"

src/lib.rs

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
extern crate failure;
44
extern crate finalfrontier;
5+
extern crate ndarray;
56
extern crate pyo3;
67

78
use std::fs::File;
89
use std::io::BufReader;
10+
use std::rc::Rc;
911

1012
use failure::Error;
1113
use finalfrontier::similarity::{Analogy, Similarity};
12-
use finalfrontier::{MmapModelBinary, ReadModelBinary};
14+
use finalfrontier::{MmapModelBinary, Model, ReadModelBinary};
15+
use ndarray::Axis;
1316
use pyo3::prelude::*;
1417

1518
/// This is a binding for finalfrontier.
@@ -57,7 +60,7 @@ impl PyObjectProtocol for PyWordSimilarity {
5760
/// A finalfrontier model.
5861
#[pyclass(name=Model)]
5962
struct PyModel {
60-
model: finalfrontier::Model,
63+
model: Rc<Model>,
6164
token: PyToken,
6265
}
6366

@@ -73,9 +76,9 @@ impl PyModel {
7376
#[args(mmap = false)]
7477
fn __new__(obj: &PyRawObject, path: &str, mmap: bool) -> PyResult<()> {
7578
let model = match load_model(path, mmap) {
76-
Ok(model) => model,
79+
Ok(model) => Rc::new(model),
7780
Err(err) => {
78-
return Err(exc::IOError::new(err.to_string()));
81+
return Err(exc::IOError::py_err(err.to_string()));
7982
}
8083
};
8184

@@ -97,7 +100,7 @@ impl PyModel {
97100
) -> PyResult<Vec<PyObject>> {
98101
let results = match self.model.analogy(word1, word2, word3, limit) {
99102
Some(results) => results,
100-
None => return Err(exc::KeyError::new("Unknown word and n-grams")),
103+
None => return Err(exc::KeyError::py_err("Unknown word and n-grams")),
101104
};
102105

103106
let mut r = Vec::with_capacity(results.len());
@@ -121,7 +124,7 @@ impl PyModel {
121124
fn embedding(&self, word: &str) -> PyResult<Vec<f32>> {
122125
match self.model.embedding(word) {
123126
Some(embedding) => Ok(embedding.to_vec()),
124-
None => Err(exc::KeyError::new("Unknown word and n-grams")),
127+
None => Err(exc::KeyError::py_err("Unknown word and n-grams")),
125128
}
126129
}
127130

@@ -130,7 +133,7 @@ impl PyModel {
130133
fn similarity(&self, py: Python, word: &str, limit: usize) -> PyResult<Vec<PyObject>> {
131134
let results = match self.model.similarity(word, limit) {
132135
Some(results) => results,
133-
None => return Err(exc::KeyError::new("Unknown word and n-grams")),
136+
None => return Err(exc::KeyError::py_err("Unknown word and n-grams")),
134137
};
135138

136139
let mut r = Vec::with_capacity(results.len());
@@ -148,14 +151,57 @@ impl PyModel {
148151
}
149152
}
150153

151-
fn load_model(path: &str, mmap: bool) -> Result<finalfrontier::Model, Error> {
154+
#[pyproto]
155+
impl PyIterProtocol for PyModel {
156+
fn __iter__(&mut self) -> PyResult<PyObject> {
157+
let gil = Python::acquire_gil();
158+
let py = gil.python();
159+
let iter = Py::new(py, |token| PyModelIterator {
160+
model: self.model.clone(),
161+
idx: 0,
162+
token,
163+
})?.into_object(py);
164+
165+
Ok(iter)
166+
}
167+
}
168+
169+
fn load_model(path: &str, mmap: bool) -> Result<Model, Error> {
152170
let f = File::open(path)?;
153171

154172
let model = if mmap {
155-
finalfrontier::Model::mmap_model_binary(f)?
173+
Model::mmap_model_binary(f)?
156174
} else {
157-
finalfrontier::Model::read_model_binary(&mut BufReader::new(f))?
175+
Model::read_model_binary(&mut BufReader::new(f))?
158176
};
159177

160178
Ok(model)
161179
}
180+
181+
#[pyclass(name=ModelIterator)]
182+
struct PyModelIterator {
183+
model: Rc<Model>,
184+
idx: usize,
185+
token: PyToken,
186+
}
187+
188+
#[pyproto]
189+
impl PyIterProtocol for PyModelIterator {
190+
fn __iter__(&mut self) -> PyResult<PyObject> {
191+
Ok(self.into())
192+
}
193+
194+
fn __next__(&mut self) -> PyResult<Option<(String, Vec<f32>)>> {
195+
let vocab = self.model.vocab();
196+
let embeddings = self.model.embedding_matrix();
197+
198+
if self.idx < vocab.len() {
199+
let word = vocab.words()[self.idx].word().to_string();
200+
let embed = embeddings.subview(Axis(0), self.idx).to_vec();
201+
self.idx += 1;
202+
Ok(Some((word, embed)))
203+
} else {
204+
Ok(None)
205+
}
206+
}
207+
}

0 commit comments

Comments
 (0)