Skip to content

Commit c6c049e

Browse files
committed
Add storage module.
Split storage-related methods into their own module and add PyStorage as an interface to it. Add a shape method to PyStorage.
1 parent ab7ef61 commit c6c049e

File tree

4 files changed

+112
-48
lines changed

4 files changed

+112
-48
lines changed

src/embeddings.rs

Lines changed: 7 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -12,14 +12,14 @@ use finalfusion::io as ffio;
1212
use finalfusion::prelude::*;
1313
use finalfusion::similarity::*;
1414
use itertools::Itertools;
15-
use ndarray::Array2;
16-
use numpy::{IntoPyArray, NpyDataType, PyArray1, PyArray2, ToPyArray};
15+
use numpy::{IntoPyArray, NpyDataType, PyArray1};
1716
use pyo3::class::iter::PyIterProtocol;
1817
use pyo3::prelude::*;
1918
use pyo3::types::{PyAny, PyTuple};
2019
use pyo3::{exceptions, PyMappingProtocol};
2120
use toml::{self, Value};
2221

22+
use crate::storage::PyStorage;
2323
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
2424

2525
/// finalfusion embeddings.
@@ -34,24 +34,6 @@ pub struct PyEmbeddings {
3434
embeddings: Rc<RefCell<EmbeddingsWrap>>,
3535
}
3636

37-
impl PyEmbeddings {
38-
/// Copy storage to an array.
39-
///
40-
/// This should only be used for storage types that do not provide
41-
/// an ndarray view that can be copied trivially, such as quantized
42-
/// storage.
43-
fn copy_storage_to_array(storage: &dyn Storage) -> Array2<f32> {
44-
let (rows, dims) = storage.shape();
45-
46-
let mut array = Array2::<f32>::zeros((rows, dims));
47-
for idx in 0..rows {
48-
array.row_mut(idx).assign(&storage.embedding(idx).as_view());
49-
}
50-
51-
array
52-
}
53-
}
54-
5537
#[pymethods]
5638
impl PyEmbeddings {
5739
/// Load embeddings from the given `path`.
@@ -156,6 +138,11 @@ impl PyEmbeddings {
156138
Ok(PyVocab::new(self.embeddings.clone()))
157139
}
158140

141+
/// Get the model's storage.
142+
fn storage(&self) -> PyStorage {
143+
PyStorage::new(self.embeddings.clone())
144+
}
145+
159146
/// Perform an analogy query.
160147
///
161148
/// This returns words for the analogy query *w1* is to *w2*
@@ -222,31 +209,6 @@ impl PyEmbeddings {
222209
})
223210
}
224211

225-
/// Copy the entire embeddings matrix.
226-
fn matrix_copy(&self) -> Py<PyArray2<f32>> {
227-
let embeddings = self.embeddings.borrow();
228-
229-
use EmbeddingsWrap::*;
230-
let gil = pyo3::Python::acquire_gil();
231-
let matrix_view = match &*embeddings {
232-
View(e) => e.storage().view(),
233-
NonView(e) => match e.storage() {
234-
StorageWrap::MmapArray(mmap) => mmap.view(),
235-
StorageWrap::NdArray(array) => array.view(),
236-
StorageWrap::QuantizedArray(quantized) => {
237-
let array = Self::copy_storage_to_array(quantized.as_ref());
238-
return array.to_pyarray(gil.python()).to_owned();
239-
}
240-
StorageWrap::MmapQuantizedArray(quantized) => {
241-
let array = Self::copy_storage_to_array(quantized);
242-
return array.to_pyarray(gil.python()).to_owned();
243-
}
244-
},
245-
};
246-
247-
matrix_view.to_pyarray(gil.python()).to_owned()
248-
}
249-
250212
/// Embeddings metadata.
251213
#[getter]
252214
fn metadata(&self) -> PyResult<Option<String>> {

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,9 @@ use similarity::PyWordSimilarity;
1717
mod vocab;
1818
use vocab::PyVocab;
1919

20+
mod storage;
21+
use storage::PyStorage;
22+
2023
/// This is a Python module for using finalfusion embeddings.
2124
///
2225
/// finalfusion is a format for word embeddings that supports words,
@@ -25,6 +28,7 @@ use vocab::PyVocab;
2528
fn finalfusion(_py: Python, m: &PyModule) -> PyResult<()> {
2629
m.add_class::<PyEmbeddings>()?;
2730
m.add_class::<PyEmbedding>()?;
31+
m.add_class::<PyStorage>()?;
2832
m.add_class::<PyWordSimilarity>()?;
2933
m.add_class::<PyVocab>()?;
3034
Ok(())

src/storage.rs

Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,96 @@
1+
use std::cell::RefCell;
2+
use std::rc::Rc;
3+
4+
use finalfusion::prelude::{Storage, StorageView, StorageWrap};
5+
use ndarray::Array2;
6+
use numpy::{PyArray1, PyArray2, ToPyArray};
7+
use pyo3::class::sequence::PySequenceProtocol;
8+
use pyo3::exceptions;
9+
use pyo3::prelude::*;
10+
11+
use crate::EmbeddingsWrap;
12+
13+
/// finalfusion storage.
14+
#[pyclass(name=Storage)]
15+
pub struct PyStorage {
16+
embeddings: Rc<RefCell<EmbeddingsWrap>>,
17+
}
18+
19+
impl PyStorage {
20+
pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>) -> Self {
21+
PyStorage { embeddings }
22+
}
23+
/// Copy storage to an array.
24+
///
25+
/// This should only be used for storage types that do not provide
26+
/// an ndarray view that can be copied trivially, such as quantized
27+
/// storage.
28+
fn copy_storage_to_array(storage: &dyn Storage) -> Array2<f32> {
29+
let (rows, dims) = storage.shape();
30+
31+
let mut array = Array2::<f32>::zeros((rows, dims));
32+
for idx in 0..rows {
33+
array.row_mut(idx).assign(&storage.embedding(idx).as_view());
34+
}
35+
36+
array
37+
}
38+
}
39+
40+
#[pymethods]
41+
impl PyStorage {
42+
/// Copy the entire embeddings matrix.
43+
fn matrix_copy(&self) -> Py<PyArray2<f32>> {
44+
let embeddings = self.embeddings.borrow();
45+
46+
use EmbeddingsWrap::*;
47+
let gil = pyo3::Python::acquire_gil();
48+
let matrix_view = match &*embeddings {
49+
View(e) => e.storage().view(),
50+
NonView(e) => match e.storage() {
51+
StorageWrap::MmapArray(mmap) => mmap.view(),
52+
StorageWrap::NdArray(array) => array.view(),
53+
StorageWrap::QuantizedArray(quantized) => {
54+
let array = Self::copy_storage_to_array(quantized.as_ref());
55+
return array.to_pyarray(gil.python()).to_owned();
56+
}
57+
StorageWrap::MmapQuantizedArray(quantized) => {
58+
let array = Self::copy_storage_to_array(quantized);
59+
return array.to_pyarray(gil.python()).to_owned();
60+
}
61+
},
62+
};
63+
64+
matrix_view.to_pyarray(gil.python()).to_owned()
65+
}
66+
67+
/// Get the shape of the storage.
68+
fn shape(&self) -> (usize, usize) {
69+
let embeddings = self.embeddings.borrow();
70+
embeddings.storage().shape()
71+
}
72+
}
73+
74+
#[pyproto]
75+
impl PySequenceProtocol for PyStorage {
76+
fn __len__(&self) -> PyResult<usize> {
77+
let embeds = self.embeddings.borrow();
78+
Ok(embeds.storage().shape().0)
79+
}
80+
81+
fn __getitem__(&self, idx: isize) -> PyResult<Py<PyArray1<f32>>> {
82+
let embeds = self.embeddings.borrow();
83+
let storage = embeds.storage();
84+
85+
if idx >= storage.shape().0 as isize || idx < 0 {
86+
Err(exceptions::IndexError::py_err("list index out of range"))
87+
} else {
88+
let gil = Python::acquire_gil();
89+
Ok(storage
90+
.embedding(idx as usize)
91+
.into_owned()
92+
.to_pyarray(gil.python())
93+
.to_owned())
94+
}
95+
}
96+
}

tests/test_embeddings.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,17 +17,19 @@ def test_embeddings(embeddings_fifu, embeddings_text, embeddings_text_dims):
1717
assert len(embeddings_fifu.vocab()) == 7
1818
assert len(embeddings_text.vocab()) == 7
1919
assert len(embeddings_text_dims.vocab()) == 7
20-
20+
fifu_storage = embeddings_fifu.storage()
2121
# Check that the finalfusion embeddings have the correct dimensionality
2222
# The correct dimensionality of the other embedding types is asserted
2323
# in the pairwise comparisons below.
24-
assert embeddings_fifu.matrix_copy().shape == (7, 10)
24+
assert fifu_storage.shape() == (7, 10)
2525

26-
for embedding in embeddings_fifu:
26+
for embedding, storage_row in zip(embeddings_fifu, fifu_storage):
2727
assert numpy.allclose(
2828
embedding.embedding, embeddings_text[embedding.word]), "FiFu and text embedding mismatch"
2929
assert numpy.allclose(
3030
embedding.embedding, embeddings_text_dims[embedding.word]), "FiFu and textdims embedding mismatch"
31+
assert numpy.allclose(
32+
embedding.embedding, storage_row), "FiFu and storage row mismatch"
3133

3234

3335
def test_embeddings_pq(similarity_fifu, similarity_pq):

0 commit comments

Comments
 (0)