Skip to content

Commit de0df73

Browse files
committed
Add Embeddings::write for saving embeddings
1 parent 6869add commit de0df73

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ features = ["extension-module"]
1414

1515
[dependencies]
1616
failure = "0.1"
17-
rust2vec = "0.5"
17+
rust2vec = "0.5.1"
1818
toml = "0.4"

src/lib.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
use std::cell::RefCell;
44
use std::fs::File;
5-
use std::io::BufReader;
5+
use std::io::{BufReader, BufWriter};
66
use std::rc::Rc;
77

88
use failure::Error;
@@ -237,6 +237,24 @@ impl PyEmbeddings {
237237

238238
Ok(r)
239239
}
240+
241+
/// Write the embeddings to a finalfusion file.
242+
fn write(&self, filename: &str) -> PyResult<()> {
243+
let f = File::create(filename)?;
244+
let mut writer = BufWriter::new(f);
245+
246+
let embeddings = self.embeddings.borrow();
247+
248+
use EmbeddingsWrap::*;
249+
match &*embeddings {
250+
View(e) => e
251+
.write_embeddings(&mut writer)
252+
.map_err(|err| exceptions::IOError::py_err(err.to_string())),
253+
NonView(e) => e
254+
.write_embeddings(&mut writer)
255+
.map_err(|err| exceptions::IOError::py_err(err.to_string())),
256+
}
257+
}
240258
}
241259

242260
#[pyproto]

0 commit comments

Comments
 (0)