Skip to content

Commit 1f2f840

Browse files
sebpuetz authored and
Daniël de Kok committed
Add a method to copy the entire matrix.
1 parent 9b1ba2a commit 1f2f840

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@ features = ["extension-module"]
2222
failure = "0.1"
2323
finalfusion = "0.5"
2424
libc = "0.2"
25+
ndarray = "0.12"
2526
numpy = "0.5"
2627
toml = "0.4"

src/lib.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ use failure::Error;
99
use finalfusion::metadata::Metadata;
1010
use finalfusion::prelude::*;
1111
use finalfusion::similarity::*;
12-
use numpy::{IntoPyArray, PyArray1};
12+
use ndarray::Array2;
13+
use numpy::{IntoPyArray, PyArray1, PyArray2};
1314
use pyo3::class::{basic::PyObjectProtocol, iter::PyIterProtocol};
1415
use pyo3::exceptions;
1516
use pyo3::prelude::*;
@@ -164,6 +165,34 @@ impl PyEmbeddings {
164165
}
165166
}
166167

168+
/// Copy the entire embeddings matrix.
169+
///
170+
/// Quantized embeddings are reconstructed row by row into a newly allocated matrix.
171+
fn matrix_copy(&self) -> PyResult<Py<PyArray2<f32>>> {
172+
let embeddings = self.embeddings.borrow();
173+
174+
use EmbeddingsWrap::*;
175+
let matrix = match &*embeddings {
176+
View(e) => e.storage().view().to_owned(),
177+
NonView(e) => match e.storage() {
178+
StorageWrap::MmapArray(mmap) => mmap.view().to_owned(),
179+
StorageWrap::NdArray(array) => array.0.to_owned(),
180+
StorageWrap::QuantizedArray(quantized) => {
181+
let (rows, dims) = quantized.shape();
182+
let mut array = Array2::<f32>::zeros((rows, dims));
183+
for idx in 0..rows {
184+
array
185+
.row_mut(idx)
186+
.assign(&quantized.embedding(idx).as_view());
187+
}
188+
array
189+
}
190+
},
191+
};
192+
let gil = pyo3::Python::acquire_gil();
193+
Ok(matrix.into_pyarray(gil.python()).to_owned())
194+
}
195+
167196
/// Embeddings metadata.
168197
#[getter]
169198
fn metadata(&self) -> PyResult<Option<String>> {

0 commit comments

Comments
 (0)