Skip to content

Commit 6d4a949

Browse files
danieldk (Daniël de Kok)
authored and committed
Split lib.rs into several modules
1 parent 329da32 commit 6d4a949

File tree

6 files changed

+460
-417
lines changed

6 files changed

+460
-417
lines changed

src/embeddings.rs

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
use std::cell::RefCell;
2+
use std::fs::File;
3+
use std::io::{BufReader, BufWriter};
4+
use std::rc::Rc;
5+
6+
use failure::Error;
7+
use finalfusion::metadata::Metadata;
8+
use finalfusion::prelude::*;
9+
use finalfusion::similarity::*;
10+
use ndarray::Array2;
11+
use numpy::{IntoPyArray, PyArray1, PyArray2};
12+
use pyo3::class::iter::PyIterProtocol;
13+
use pyo3::exceptions;
14+
use pyo3::prelude::*;
15+
use toml::{self, Value};
16+
17+
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
18+
19+
/// finalfusion embeddings.
20+
#[pyclass(name=Embeddings)]
21+
pub struct PyEmbeddings {
22+
// The use of Rc + RefCell should be safe in this crate:
23+
//
24+
// 1. Python is single-threaded.
25+
// 2. The only mutable borrow (in set_metadata) is limited
26+
// to its method scope.
27+
// 3. None of the methods returns borrowed embeddings.
28+
embeddings: Rc<RefCell<EmbeddingsWrap>>,
29+
}
30+
31+
#[pymethods]
32+
impl PyEmbeddings {
33+
/// Load embeddings from the given `path`.
34+
///
35+
/// When the `mmap` argument is `True`, the embedding matrix is
36+
/// not loaded into memory, but memory mapped. This results in
37+
/// lower memory use and shorter load times, while sacrificing
38+
/// some query efficiency.
39+
#[new]
40+
#[args(mmap = false)]
41+
fn __new__(obj: &PyRawObject, path: &str, mmap: bool) -> PyResult<()> {
42+
// First try to load embeddings with viewable storage. If that
43+
// fails, attempt to load the embeddings as non-viewable
44+
// storage.
45+
let embeddings = match load_embeddings(path, mmap) {
46+
Ok(e) => Rc::new(RefCell::new(EmbeddingsWrap::View(e))),
47+
Err(_) => load_embeddings(path, mmap)
48+
.map(|e| Rc::new(RefCell::new(EmbeddingsWrap::NonView(e))))
49+
.map_err(|err| exceptions::IOError::py_err(err.to_string()))?,
50+
};
51+
52+
obj.init(PyEmbeddings { embeddings });
53+
54+
Ok(())
55+
}
56+
57+
/// Get the model's vocabulary.
58+
fn vocab(&self) -> PyResult<PyVocab> {
59+
Ok(PyVocab::new(self.embeddings.clone()))
60+
}
61+
62+
/// Perform an anology query.
63+
///
64+
/// This returns words for the analogy query *w1* is to *w2*
65+
/// as *w3* is to ?.
66+
#[args(limit = 10)]
67+
fn analogy(
68+
&self,
69+
py: Python,
70+
word1: &str,
71+
word2: &str,
72+
word3: &str,
73+
limit: usize,
74+
) -> PyResult<Vec<PyObject>> {
75+
use EmbeddingsWrap::*;
76+
let embeddings = self.embeddings.borrow();
77+
let embeddings = match &*embeddings {
78+
View(e) => e,
79+
NonView(_) => {
80+
return Err(exceptions::ValueError::py_err(
81+
"Analogy queries are not supported for this type of embedding matrix",
82+
));
83+
}
84+
};
85+
86+
let results = match embeddings.analogy(word1, word2, word3, limit) {
87+
Some(results) => results,
88+
None => return Err(exceptions::KeyError::py_err("Unknown word or n-grams")),
89+
};
90+
91+
let mut r = Vec::with_capacity(results.len());
92+
for ws in results {
93+
r.push(
94+
Py::new(
95+
py,
96+
PyWordSimilarity::new(ws.word.to_owned(), ws.similarity.into_inner()),
97+
)?
98+
.into_object(py),
99+
)
100+
}
101+
102+
Ok(r)
103+
}
104+
105+
/// Get the embedding for the given word.
106+
///
107+
/// If the word is not known, its representation is approximated
108+
/// using subword units.
109+
fn embedding(&self, word: &str) -> PyResult<Py<PyArray1<f32>>> {
110+
let embeddings = self.embeddings.borrow();
111+
112+
use EmbeddingsWrap::*;
113+
let embedding = match &*embeddings {
114+
View(e) => e.embedding(word),
115+
NonView(e) => e.embedding(word),
116+
};
117+
118+
match embedding {
119+
Some(embedding) => {
120+
let gil = pyo3::Python::acquire_gil();
121+
Ok(embedding.into_owned().into_pyarray(gil.python()).to_owned())
122+
}
123+
None => Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
124+
}
125+
}
126+
127+
/// Copy the entire embeddings matrix.
128+
fn matrix_copy(&self) -> PyResult<Py<PyArray2<f32>>> {
129+
let embeddings = self.embeddings.borrow();
130+
131+
use EmbeddingsWrap::*;
132+
let matrix = match &*embeddings {
133+
View(e) => e.storage().view().to_owned(),
134+
NonView(e) => match e.storage() {
135+
StorageWrap::MmapArray(mmap) => mmap.view().to_owned(),
136+
StorageWrap::NdArray(array) => array.0.to_owned(),
137+
StorageWrap::QuantizedArray(quantized) => {
138+
let (rows, dims) = quantized.shape();
139+
let mut array = Array2::<f32>::zeros((rows, dims));
140+
for idx in 0..rows {
141+
array
142+
.row_mut(idx)
143+
.assign(&quantized.embedding(idx).as_view());
144+
}
145+
array
146+
}
147+
},
148+
};
149+
let gil = pyo3::Python::acquire_gil();
150+
Ok(matrix.into_pyarray(gil.python()).to_owned())
151+
}
152+
153+
/// Embeddings metadata.
154+
#[getter]
155+
fn metadata(&self) -> PyResult<Option<String>> {
156+
let embeddings = self.embeddings.borrow();
157+
158+
use EmbeddingsWrap::*;
159+
let metadata = match &*embeddings {
160+
View(e) => e.metadata(),
161+
NonView(e) => e.metadata(),
162+
};
163+
164+
match metadata.map(|v| toml::ser::to_string_pretty(&v.0)) {
165+
Some(Ok(toml)) => Ok(Some(toml)),
166+
Some(Err(err)) => Err(exceptions::IOError::py_err(format!(
167+
"Metadata is invalid TOML: {}",
168+
err
169+
))),
170+
None => Ok(None),
171+
}
172+
}
173+
174+
#[setter]
175+
fn set_metadata(&mut self, metadata: &str) -> PyResult<()> {
176+
let value = match metadata.parse::<Value>() {
177+
Ok(value) => value,
178+
Err(err) => {
179+
return Err(exceptions::ValueError::py_err(format!(
180+
"Metadata is invalid TOML: {}",
181+
err
182+
)));
183+
}
184+
};
185+
186+
let mut embeddings = self.embeddings.borrow_mut();
187+
188+
use EmbeddingsWrap::*;
189+
match &mut *embeddings {
190+
View(e) => e.set_metadata(Some(Metadata(value))),
191+
NonView(e) => e.set_metadata(Some(Metadata(value))),
192+
};
193+
194+
Ok(())
195+
}
196+
197+
/// Perform a similarity query.
198+
#[args(limit = 10)]
199+
fn similarity(&self, py: Python, word: &str, limit: usize) -> PyResult<Vec<PyObject>> {
200+
let embeddings = self.embeddings.borrow();
201+
202+
use EmbeddingsWrap::*;
203+
let embeddings = match &*embeddings {
204+
View(e) => e,
205+
NonView(_) => {
206+
return Err(exceptions::ValueError::py_err(
207+
"Similarity queries are not supported for this type of embedding matrix",
208+
));
209+
}
210+
};
211+
212+
let results = match embeddings.similarity(word, limit) {
213+
Some(results) => results,
214+
None => return Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
215+
};
216+
217+
let mut r = Vec::with_capacity(results.len());
218+
for ws in results {
219+
r.push(
220+
Py::new(
221+
py,
222+
PyWordSimilarity::new(ws.word.to_owned(), ws.similarity.into_inner()),
223+
)?
224+
.into_object(py),
225+
)
226+
}
227+
228+
Ok(r)
229+
}
230+
231+
/// Write the embeddings to a finalfusion file.
232+
fn write(&self, filename: &str) -> PyResult<()> {
233+
let f = File::create(filename)?;
234+
let mut writer = BufWriter::new(f);
235+
236+
let embeddings = self.embeddings.borrow();
237+
238+
use EmbeddingsWrap::*;
239+
match &*embeddings {
240+
View(e) => e
241+
.write_embeddings(&mut writer)
242+
.map_err(|err| exceptions::IOError::py_err(err.to_string())),
243+
NonView(e) => e
244+
.write_embeddings(&mut writer)
245+
.map_err(|err| exceptions::IOError::py_err(err.to_string())),
246+
}
247+
}
248+
}
249+
250+
#[pyproto]
251+
impl PyIterProtocol for PyEmbeddings {
252+
fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyObject> {
253+
let gil = Python::acquire_gil();
254+
let py = gil.python();
255+
let iter =
256+
Py::new(py, PyEmbeddingIterator::new(slf.embeddings.clone(), 0))?.into_object(py);
257+
258+
Ok(iter)
259+
}
260+
}
261+
262+
fn load_embeddings<S>(path: &str, mmap: bool) -> Result<Embeddings<VocabWrap, S>, Error>
263+
where
264+
Embeddings<VocabWrap, S>: ReadEmbeddings + MmapEmbeddings,
265+
{
266+
let f = File::open(path)?;
267+
let mut reader = BufReader::new(f);
268+
269+
let embeddings = if mmap {
270+
Embeddings::mmap_embeddings(&mut reader)?
271+
} else {
272+
Embeddings::read_embeddings(&mut reader)?
273+
};
274+
275+
Ok(embeddings)
276+
}

src/embeddings_wrap.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use finalfusion::prelude::*;
2+
3+
pub enum EmbeddingsWrap {
4+
NonView(Embeddings<VocabWrap, StorageWrap>),
5+
View(Embeddings<VocabWrap, StorageViewWrap>),
6+
}
7+
8+
impl EmbeddingsWrap {
9+
pub fn storage(&self) -> &Storage {
10+
use EmbeddingsWrap::*;
11+
match self {
12+
NonView(e) => e.storage(),
13+
View(e) => e.storage(),
14+
}
15+
}
16+
17+
pub fn vocab(&self) -> &VocabWrap {
18+
use EmbeddingsWrap::*;
19+
match self {
20+
NonView(e) => e.vocab(),
21+
View(e) => e.vocab(),
22+
}
23+
}
24+
}

src/iter.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
use std::cell::RefCell;
2+
use std::rc::Rc;
3+
4+
use finalfusion::prelude::*;
5+
use numpy::{IntoPyArray, PyArray1};
6+
use pyo3::class::iter::PyIterProtocol;
7+
use pyo3::prelude::*;
8+
9+
use crate::EmbeddingsWrap;
10+
11+
#[pyclass(name=EmbeddingIterator)]
12+
pub struct PyEmbeddingIterator {
13+
embeddings: Rc<RefCell<EmbeddingsWrap>>,
14+
idx: usize,
15+
}
16+
17+
impl PyEmbeddingIterator {
18+
pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>, idx: usize) -> Self {
19+
PyEmbeddingIterator { embeddings, idx }
20+
}
21+
}
22+
23+
#[pyproto]
24+
impl PyIterProtocol for PyEmbeddingIterator {
25+
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<PyEmbeddingIterator>> {
26+
Ok(slf.into())
27+
}
28+
29+
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<(String, Py<PyArray1<f32>>)>> {
30+
let slf = &mut *slf;
31+
32+
let embeddings = slf.embeddings.borrow();
33+
let vocab = embeddings.vocab();
34+
35+
if slf.idx < vocab.len() {
36+
let word = vocab.words()[slf.idx].to_string();
37+
let embed = embeddings.storage().embedding(slf.idx);
38+
39+
slf.idx += 1;
40+
41+
let gil = pyo3::Python::acquire_gil();
42+
Ok(Some((
43+
word,
44+
embed.into_owned().into_pyarray(gil.python()).to_owned(),
45+
)))
46+
} else {
47+
Ok(None)
48+
}
49+
}
50+
}

0 commit comments

Comments
 (0)