Skip to content

Commit 8bb481f

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Make copied matrix mutable in python.
1 parent b27c5d9 commit 8bb481f

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/embeddings.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use finalfusion::prelude::*;
1313
use finalfusion::similarity::*;
1414
use itertools::Itertools;
1515
use ndarray::Array2;
16-
use numpy::{IntoPyArray, NpyDataType, PyArray1, PyArray2};
16+
use numpy::{IntoPyArray, NpyDataType, PyArray1, PyArray2, ToPyArray};
1717
use pyo3::class::iter::PyIterProtocol;
1818
use pyo3::prelude::*;
1919
use pyo3::types::{PyAny, PyList, PySet, PyTuple};
@@ -181,11 +181,12 @@ impl PyEmbeddings {
181181
let embeddings = self.embeddings.borrow();
182182

183183
use EmbeddingsWrap::*;
184-
let matrix = match &*embeddings {
185-
View(e) => e.storage().view().to_owned(),
184+
let gil = pyo3::Python::acquire_gil();
185+
let matrix_view = match &*embeddings {
186+
View(e) => e.storage().view(),
186187
NonView(e) => match e.storage() {
187-
StorageWrap::MmapArray(mmap) => mmap.view().to_owned(),
188-
StorageWrap::NdArray(array) => array.0.to_owned(),
188+
StorageWrap::MmapArray(mmap) => mmap.view(),
189+
StorageWrap::NdArray(array) => array.0.view(),
189190
StorageWrap::QuantizedArray(quantized) => {
190191
let (rows, dims) = quantized.shape();
191192
let mut array = Array2::<f32>::zeros((rows, dims));
@@ -194,12 +195,11 @@ impl PyEmbeddings {
194195
.row_mut(idx)
195196
.assign(&quantized.embedding(idx).as_view());
196197
}
197-
array
198+
return array.to_pyarray(gil.python()).to_owned();
198199
}
199200
},
200201
};
201-
let gil = pyo3::Python::acquire_gil();
202-
matrix.into_pyarray(gil.python()).to_owned()
202+
matrix_view.to_pyarray(gil.python()).to_owned()
203203
}
204204

205205
/// Embeddings metadata.

0 commit comments

Comments
 (0)