Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ serde = { version = "1.0", features = ["rc", "derive"] }
serde_json = "1.0"
libc = "0.2"
env_logger = "0.11"
pyo3 = { version = "0.25", features = ["abi3", "abi3-py39", "py-clone"] }
pyo3-async-runtimes = { version = "0.25", features = ["tokio-runtime"] }
pyo3 = { version = "0.26", features = ["abi3", "abi3-py39", "py-clone"] }
pyo3-async-runtimes = { version = "0.26", features = ["tokio-runtime"] }
tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread", "macros", "signal"] }
once_cell = "1.19.0"
numpy = "0.25"
numpy = "0.26"
ndarray = "0.16"
itertools = "0.14"
ahash = { version = "0.8.11", features = ["serde"] }
Expand All @@ -28,7 +28,7 @@ path = "../../tokenizers"

[dev-dependencies]
tempfile = "3.10"
pyo3 = { version = "0.25", features = ["auto-initialize"] }
pyo3 = { version = "0.26", features = ["auto-initialize"] }

[features]
default = ["pyo3/extension-module"]
84 changes: 26 additions & 58 deletions bindings/python/src/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,51 +40,25 @@ impl PyDecoder {
PyDecoder { decoder }
}

pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let base = self.clone();
Ok(match &self.decoder {
PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_pyobject(py)?.into_any().into(),
PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_any(),
PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::ByteFallback(_) => Py::new(py, (PyByteFallbackDec {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::Sequence(_) => Py::new(py, (PySequenceDecoder {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_any(),
DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_any(),
DecoderWrapper::ByteFallback(_) => {
Py::new(py, (PyByteFallbackDec {}, base))?.into_any()
}
DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?.into_any(),
DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?.into_any(),
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_any(),
DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_any(),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_any(),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_any(),
DecoderWrapper::Sequence(_) => {
Py::new(py, (PySequenceDecoder {}, base))?.into_any()
}
},
})
}
Expand All @@ -99,12 +73,12 @@ impl Decoder for PyDecoder {
#[pymethods]
impl PyDecoder {
#[staticmethod]
fn custom(decoder: PyObject) -> Self {
fn custom(decoder: Py<PyAny>) -> Self {
let decoder = PyDecoderWrapper::Custom(Arc::new(RwLock::new(CustomDecoder::new(decoder))));
PyDecoder::new(decoder)
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
let data = serde_json::to_string(&self.decoder).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to pickle Decoder: {e}"
Expand All @@ -113,7 +87,7 @@ impl PyDecoder {
Ok(PyBytes::new(py, data.as_bytes()).into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.decoder = serde_json::from_slice(s).map_err(|e| {
Expand Down Expand Up @@ -514,18 +488,18 @@ impl PySequenceDecoder {
}

pub(crate) struct CustomDecoder {
inner: PyObject,
inner: Py<PyAny>,
}

impl CustomDecoder {
pub(crate) fn new(inner: PyObject) -> Self {
pub(crate) fn new(inner: Py<PyAny>) -> Self {
CustomDecoder { inner }
}
}

impl Decoder for CustomDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
Python::with_gil(|py| {
Python::attach(|py| {
let decoded = self
.inner
.call_method(py, "decode", (tokens,), None)?
Expand All @@ -535,7 +509,7 @@ impl Decoder for CustomDecoder {
}

fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
Python::with_gil(|py| {
Python::attach(|py| {
let decoded = self
.inner
.call_method(py, "decode_chain", (tokens,), None)?
Expand Down Expand Up @@ -708,7 +682,7 @@ mod test {

#[test]
fn get_subtype() {
Python::with_gil(|py| {
Python::attach(|py| {
let py_dec = PyDecoder::new(Metaspace::default().into());
let py_meta = py_dec.get_as_subtype(py).unwrap();
assert_eq!("Metaspace", py_meta.bind(py).get_type().qualname().unwrap());
Expand All @@ -731,15 +705,9 @@ mod test {
_ => panic!("Expected wrapped, not custom."),
}

let obj = Python::with_gil(|py| {
let obj = Python::attach(|py| {
let py_msp = PyDecoder::new(Metaspace::default().into());
let obj: PyObject = Py::new(py, py_msp)
.unwrap()
.into_pyobject(py)
.unwrap()
.into_any()
.into();
obj
Py::new(py, py_msp).unwrap().into_any()
});
let py_seq = PyDecoderWrapper::Custom(Arc::new(RwLock::new(CustomDecoder::new(obj))));
assert!(serde_json::to_string(&py_seq).is_err());
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl PyEncoding {
}
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
let data = serde_json::to_string(&self.encoding).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to pickle Encoding: {e}"
Expand All @@ -39,7 +39,7 @@ impl PyEncoding {
Ok(PyBytes::new(py, data.as_bytes()).into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.encoding = serde_json::from_slice(s).map_err(|e| {
Expand Down
30 changes: 9 additions & 21 deletions bindings/python/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,13 @@ pub struct PyModel {
}

impl PyModel {
pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
pub(crate) fn get_as_subtype<'py>(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let base = self.clone();
Ok(match *self.model.as_ref().read().unwrap() {
ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_any(),
ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_any(),
ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_any(),
ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_any(),
})
}
}
Expand Down Expand Up @@ -111,14 +99,14 @@ impl PyModel {
}
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
let data = serde_json::to_string(&self.model).map_err(|e| {
exceptions::PyException::new_err(format!("Error while attempting to pickle Model: {e}"))
})?;
Ok(PyBytes::new(py, data.as_bytes()).into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.model = serde_json::from_slice(s).map_err(|e| {
Expand Down Expand Up @@ -226,7 +214,7 @@ impl PyModel {
/// Returns:
/// :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model
#[pyo3(text_signature = "(self)")]
fn get_trainer(&self, py: Python<'_>) -> PyResult<PyObject> {
fn get_trainer(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py)
}

Expand Down Expand Up @@ -945,7 +933,7 @@ mod test {

#[test]
fn get_subtype() {
Python::with_gil(|py| {
Python::attach(|py| {
let py_model = PyModel::from(BPE::default());
let py_bpe = py_model.get_as_subtype(py).unwrap();
assert_eq!("BPE", py_bpe.bind(py).get_type().qualname().unwrap());
Expand Down
Loading
Loading