diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 483865879..d51c83112 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -14,11 +14,11 @@ serde = { version = "1.0", features = ["rc", "derive"] }
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.11"
-pyo3 = { version = "0.25", features = ["abi3", "abi3-py39", "py-clone"] }
-pyo3-async-runtimes = { version = "0.25", features = ["tokio-runtime"] }
+pyo3 = { version = "0.26", features = ["abi3", "abi3-py39", "py-clone"] }
+pyo3-async-runtimes = { version = "0.26", features = ["tokio-runtime"] }
 tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread", "macros", "signal"] }
 once_cell = "1.19.0"
-numpy = "0.25"
+numpy = "0.26"
 ndarray = "0.16"
 itertools = "0.14"
 ahash = { version = "0.8.11", features = ["serde"] }
@@ -28,7 +28,7 @@ path = "../../tokenizers"
 
 [dev-dependencies]
 tempfile = "3.10"
-pyo3 = { version = "0.25", features = ["auto-initialize"] }
+pyo3 = { version = "0.26", features = ["auto-initialize"] }
 
 [features]
 default = ["pyo3/extension-module"]
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index 4b9367ef4..d0741e00d 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -40,51 +40,25 @@ impl PyDecoder {
         PyDecoder { decoder }
     }
 
-    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
+    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         let base = self.clone();
         Ok(match &self.decoder {
-            PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_pyobject(py)?.into_any().into(),
+            PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_any(),
             PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
-                DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::ByteFallback(_) => Py::new(py, (PyByteFallbackDec {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
-                DecoderWrapper::Sequence(_) => Py::new(py, (PySequenceDecoder {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
+                DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_any(),
+                DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_any(),
+                DecoderWrapper::ByteFallback(_) => {
+                    Py::new(py, (PyByteFallbackDec {}, base))?.into_any()
+                }
+                DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?.into_any(),
+                DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?.into_any(),
+                DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_any(),
+                DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_any(),
+                DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_any(),
+                DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_any(),
+                DecoderWrapper::Sequence(_) => {
+                    Py::new(py, (PySequenceDecoder {}, base))?.into_any()
+                }
             },
         })
     }
@@ -99,12 +73,12 @@ impl PyDecoder {
 #[pymethods]
 impl PyDecoder {
     #[staticmethod]
-    fn custom(decoder: PyObject) -> Self {
+    fn custom(decoder: Py<PyAny>) -> Self {
         let decoder = PyDecoderWrapper::Custom(Arc::new(RwLock::new(CustomDecoder::new(decoder))));
         PyDecoder::new(decoder)
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.decoder).map_err(|e| {
             exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Decoder: {e}"
@@ -113,7 +87,7 @@ impl PyDecoder {
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 self.decoder = serde_json::from_slice(s).map_err(|e| {
@@ -514,18 +488,18 @@ impl PySequenceDecoder {
 }
 
 pub(crate) struct CustomDecoder {
-    inner: PyObject,
+    inner: Py<PyAny>,
 }
 
 impl CustomDecoder {
-    pub(crate) fn new(inner: PyObject) -> Self {
+    pub(crate) fn new(inner: Py<PyAny>) -> Self {
         CustomDecoder { inner }
     }
 }
 
 impl Decoder for CustomDecoder {
     fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode", (tokens,), None)?
@@ -535,7 +509,7 @@ impl Decoder for CustomDecoder {
     }
 
     fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode_chain", (tokens,), None)?
@@ -708,7 +682,7 @@ mod test {
 
     #[test]
    fn get_subtype() {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let py_dec = PyDecoder::new(Metaspace::default().into());
             let py_meta = py_dec.get_as_subtype(py).unwrap();
             assert_eq!("Metaspace", py_meta.bind(py).get_type().qualname().unwrap());
@@ -731,15 +705,9 @@ mod test {
             _ => panic!("Expected wrapped, not custom."),
         }
 
-        let obj = Python::with_gil(|py| {
+        let obj = Python::attach(|py| {
             let py_msp = PyDecoder::new(Metaspace::default().into());
-            let obj: PyObject = Py::new(py, py_msp)
-                .unwrap()
-                .into_pyobject(py)
-                .unwrap()
-                .into_any()
-                .into();
-            obj
+            Py::new(py, py_msp).unwrap().into_any()
         });
         let py_seq = PyDecoderWrapper::Custom(Arc::new(RwLock::new(CustomDecoder::new(obj))));
         assert!(serde_json::to_string(&py_seq).is_err());
diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs
index 899ee5bfb..3080541ca 100644
--- a/bindings/python/src/encoding.rs
+++ b/bindings/python/src/encoding.rs
@@ -30,7 +30,7 @@ impl PyEncoding {
         }
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.encoding).map_err(|e| {
             exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Encoding: {e}"
@@ -39,7 +39,7 @@ impl PyEncoding {
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 self.encoding = serde_json::from_slice(s).map_err(|e| {
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 4e79be447..b1e0f93b8 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -33,25 +33,13 @@ pub struct PyModel {
 }
 
 impl PyModel {
-    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
+    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         let base = self.clone();
         Ok(match *self.model.as_ref().read().unwrap() {
-            ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
-            ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
-            ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
-            ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
+            ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_any(),
+            ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_any(),
+            ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_any(),
+            ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_any(),
         })
     }
 }
@@ -111,14 +99,14 @@ impl PyModel {
         }
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.model).map_err(|e| {
             exceptions::PyException::new_err(format!("Error while attempting to pickle Model: {e}"))
         })?;
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 self.model = serde_json::from_slice(s).map_err(|e| {
@@ -226,7 +214,7 @@ impl PyModel {
     /// Returns:
     ///     :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model
     #[pyo3(text_signature = "(self)")]
-    fn get_trainer(&self, py: Python<'_>) -> PyResult<PyObject> {
+    fn get_trainer(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py)
     }
 
@@ -945,7 +933,7 @@ mod test {
 
     #[test]
     fn get_subtype() {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let py_model = PyModel::from(BPE::default());
             let py_bpe = py_model.get_as_subtype(py).unwrap();
             assert_eq!("BPE", py_bpe.bind(py).get_type().qualname().unwrap());
diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs
index c5697f75b..783404199 100644
--- a/bindings/python/src/normalizers.rs
+++ b/bindings/python/src/normalizers.rs
@@ -52,80 +52,48 @@ impl PyNormalizer {
     pub(crate) fn new(normalizer: PyNormalizerTypeWrapper) -> Self {
         PyNormalizer { normalizer }
     }
-    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
+    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         let base = self.clone();
         Ok(match self.normalizer {
-            PyNormalizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
+            PyNormalizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_any(),
             PyNormalizerTypeWrapper::Single(ref inner) => match &*inner
                 .as_ref()
                 .read()
                 .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
             {
                 PyNormalizerWrapper::Custom(_) => {
-                    Py::new(py, base)?.into_pyobject(py)?.into_any().into()
+                    Py::new(py, base)?.into_any()
                 }
                 PyNormalizerWrapper::Wrapped(ref inner) => match inner {
                     NormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::BertNormalizer(_) => {
-                        Py::new(py, (PyBertNormalizer {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into()
+                        Py::new(py, (PyBertNormalizer {}, base))?.into_any()
                     }
                     NormalizerWrapper::StripNormalizer(_) => Py::new(py, (PyStrip {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::Prepend(_) => Py::new(py, (PyPrepend {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::StripAccents(_) => Py::new(py, (PyStripAccents {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::Lowercase(_) => Py::new(py, (PyLowercase {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::Precompiled(_) => Py::new(py, (PyPrecompiled {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::Replace(_) => Py::new(py, (PyReplace {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                     NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base))?
-                        .into_pyobject(py)?
-                        .into_any()
-                        .into(),
+                        .into_any(),
                 },
             },
         })
@@ -141,13 +109,13 @@ impl Normalizer for PyNormalizer {
 #[pymethods]
 impl PyNormalizer {
     #[staticmethod]
-    fn custom(obj: PyObject) -> Self {
+    fn custom(obj: Py<PyAny>) -> Self {
         Self {
             normalizer: PyNormalizerWrapper::Custom(CustomNormalizer::new(obj)).into(),
         }
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.normalizer).map_err(|e| {
             exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Normalizer: {e}"
@@ -156,7 +124,7 @@ impl PyNormalizer {
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 self.normalizer = serde_json::from_slice(s).map_err(|e| {
@@ -632,17 +600,17 @@ impl PyReplace {
 #[derive(Clone, Debug)]
 pub(crate) struct CustomNormalizer {
-    inner: PyObject,
+    inner: Py<PyAny>,
 }
 
 impl CustomNormalizer {
-    pub fn new(inner: PyObject) -> Self {
+    pub fn new(inner: Py<PyAny>) -> Self {
         Self { inner }
     }
 }
 
 impl tk::tokenizer::Normalizer for CustomNormalizer {
     fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let normalized = PyNormalizedStringRefMut::new(normalized);
             let py_normalized = self.inner.bind(py);
             py_normalized.call_method("normalize", (normalized.get().clone(),), None)?;
@@ -826,7 +794,7 @@ mod test {
 
     #[test]
     fn get_subtype() {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let py_norm = PyNormalizer::new(NFC.into());
             let py_nfc = py_norm.get_as_subtype(py).unwrap();
             assert_eq!("NFC", py_nfc.bind(py).get_type().qualname().unwrap());
diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 21681715a..7fdf7168d 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -48,13 +48,11 @@ impl PyPreTokenizer {
         PyPreTokenizer { pretok }
     }
 
-    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
+    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         let base = self.clone();
         Ok(match self.pretok {
             PyPreTokenizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
+                .into_any(),
             PyPreTokenizerTypeWrapper::Single(ref inner) => {
                 match &*inner
                     .as_ref()
@@ -66,64 +64,40 @@ impl PyPreTokenizer {
                     }
                     PyPreTokenizerWrapper::Wrapped(inner) => match inner {
                         PreTokenizerWrapper::Whitespace(_) => Py::new(py, (PyWhitespace {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PreTokenizerWrapper::Split(_) => Py::new(py, (PySplit {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PreTokenizerWrapper::Punctuation(_) => {
                             Py::new(py, (PyPunctuation {}, base))?
-                                .into_pyobject(py)?
                                 .into_any()
-                                .into()
                         }
                         PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PreTokenizerWrapper::Metaspace(_) => Py::new(py, (PyMetaspace {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PreTokenizerWrapper::Delimiter(_) => {
                             Py::new(py, (PyCharDelimiterSplit {}, base))?
-                                .into_pyobject(py)?
                                 .into_any()
-                                .into()
                         }
                         PreTokenizerWrapper::WhitespaceSplit(_) => {
                             Py::new(py, (PyWhitespaceSplit {}, base))?
-                                .into_pyobject(py)?
                                 .into_any()
-                                .into()
                         }
                         PreTokenizerWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PreTokenizerWrapper::BertPreTokenizer(_) => {
                             Py::new(py, (PyBertPreTokenizer {}, base))?
-                                .into_pyobject(py)?
                                 .into_any()
-                                .into()
                         }
                         PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PreTokenizerWrapper::UnicodeScripts(_) => {
                             Py::new(py, (PyUnicodeScripts {}, base))?
-                                .into_pyobject(py)?
                                 .into_any()
-                                .into()
                         }
                         PreTokenizerWrapper::FixedLength(_) => {
                             Py::new(py, (PyFixedLength {}, base))?
-                                .into_pyobject(py)?
                                 .into_any()
-                                .into()
                         }
                     },
                 }
@@ -141,13 +115,13 @@ impl PreTokenizer for PyPreTokenizer {
 #[pymethods]
 impl PyPreTokenizer {
     #[staticmethod]
-    fn custom(pretok: PyObject) -> Self {
+    fn custom(pretok: Py<PyAny>) -> Self {
         PyPreTokenizer {
             pretok: PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(pretok)).into(),
         }
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.pretok).map_err(|e| {
             exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle PreTokenizer: {e}"
@@ -156,7 +130,7 @@ impl PyPreTokenizer {
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 let unpickled = serde_json::from_slice(s).map_err(|e| {
@@ -813,18 +787,18 @@ impl PyUnicodeScripts {
 
 #[derive(Clone)]
 pub(crate) struct CustomPreTokenizer {
-    inner: PyObject,
+    inner: Py<PyAny>,
 }
 
 impl CustomPreTokenizer {
-    pub fn new(inner: PyObject) -> Self {
+    pub fn new(inner: Py<PyAny>) -> Self {
         Self { inner }
     }
 }
 
 impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
     fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let pretok = PyPreTokenizedStringRefMut::new(sentence);
             let py_pretok = self.inner.bind(py);
             py_pretok.call_method("pre_tokenize", (pretok.get().clone(),), None)?;
@@ -1004,7 +978,7 @@ mod test {
 
     #[test]
     fn get_subtype() {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let py_norm = PyPreTokenizer::new(Whitespace {}.into());
             let py_wsp = py_norm.get_as_subtype(py).unwrap();
             assert_eq!("Whitespace", py_wsp.bind(py).get_type().qualname().unwrap());
@@ -1041,15 +1015,9 @@ mod test {
         let py_ser = serde_json::to_string(&py_seq).unwrap();
         assert_eq!(py_wrapper_ser, py_ser);
 
-        let obj = Python::with_gil(|py| {
+        let obj = Python::attach(|py| {
             let py_wsp = PyPreTokenizer::new(Whitespace {}.into());
-            let obj: PyObject = Py::new(py, py_wsp)
-                .unwrap()
-                .into_pyobject(py)
-                .unwrap()
-                .into_any()
-                .into();
-            obj
+            Py::new(py, py_wsp).unwrap().into_any()
         });
         let py_seq: PyPreTokenizerWrapper =
             PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(obj));
diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs
index 4a25c910c..8b924d98c 100644
--- a/bindings/python/src/processors.rs
+++ b/bindings/python/src/processors.rs
@@ -8,6 +8,7 @@ use pyo3::exceptions;
 use pyo3::exceptions::PyException;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::IntoPyObjectExt;
 use serde::ser::SerializeStruct;
 use serde::Deserializer;
 use serde::Serializer;
@@ -52,39 +53,26 @@ impl PyPostProcessor {
         PyPostProcessor { processor }
     }
 
-    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
+    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         let base = self.clone();
         Ok(
             match self.processor {
-                PyPostProcessorTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
-                    .into_pyobject(py)?
-                    .into_any()
-                    .into(),
+                PyPostProcessorTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_any(),
                 PyPostProcessorTypeWrapper::Single(ref inner) => {
                     match &*inner.read().map_err(|_| {
                         PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor")
                    })? {
                         PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PostProcessorWrapper::Template(_) => Py::new(py, (PyTemplateProcessing {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                         PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
-                            .into_pyobject(py)?
-                            .into_any()
-                            .into(),
+                            .into_any(),
                     }
                 }
             }
         )
@@ -110,7 +98,7 @@ impl PostProcessor for PyPostProcessor {
 
 #[pymethods]
 impl PyPostProcessor {
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.processor).map_err(|e| {
             exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle PostProcessor: {e}"
@@ -119,7 +107,7 @@ impl PyPostProcessor {
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 self.processor = serde_json::from_slice(s).map_err(|e| {
@@ -346,7 +334,7 @@ impl PyBertProcessing {
         let (tok, id) = getter!(self_, Bert, get_sep_copy());
         PyTuple::new(
             py,
-            Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
+            Vec::<Py<PyAny>>::from([tok.into_py_any(py)?, id.into_py_any(py)?]),
         )
     }
 
@@ -363,7 +351,7 @@ impl PyBertProcessing {
         let (tok, id) = getter!(self_, Bert, get_cls_copy());
         PyTuple::new(
             py,
-            Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
+            Vec::<Py<PyAny>>::from([tok.into_py_any(py)?, id.into_py_any(py)?]),
         )
     }
 
@@ -427,7 +415,7 @@ impl PyRobertaProcessing {
         let (tok, id) = getter!(self_, Roberta, get_sep_copy());
         PyTuple::new(
             py,
-            Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
+            Vec::<Py<PyAny>>::from([tok.into_py_any(py)?, id.into_py_any(py)?]),
         )
     }
 
@@ -444,7 +432,7 @@ impl PyRobertaProcessing {
         let (tok, id) = getter!(self_, Roberta, get_cls_copy());
         PyTuple::new(
             py,
-            Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
+            Vec::<Py<PyAny>>::from([tok.into_py_any(py)?, id.into_py_any(py)?]),
         )
     }
 
@@ -835,7 +823,7 @@ mod test {
 
     #[test]
     fn get_subtype() {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let py_proc = PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(Arc::new(
                 RwLock::new(BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into()),
             )));
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 0e27f594f..2a49c95c7 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -4,11 +4,11 @@ use std::hash::{Hash, Hasher};
 
 use numpy::{npyffi, PyArray1, PyArrayMethods};
 use pyo3::class::basic::CompareOp;
-use pyo3::exceptions;
 use pyo3::intern;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use pyo3::IntoPyObject;
+use pyo3::{exceptions, IntoPyObjectExt};
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
@@ -156,7 +156,7 @@ impl PyAddedToken {
         self.as_pydict(py)
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.downcast_bound::<PyDict>(py) {
             Ok(state) => {
                 for (key, value) in state {
@@ -248,7 +248,7 @@ impl PyAddedToken {
 
     fn __richcmp__(&self, other: Py<PyAddedToken>, op: CompareOp) -> bool {
         use CompareOp::*;
-        Python::with_gil(|py| match op {
+        Python::attach(|py| match op {
             Lt | Le | Gt | Ge => false,
             Eq => self.get_token() == other.borrow(py).get_token(),
             Ne => self.get_token() != other.borrow(py).get_token(),
@@ -331,7 +331,7 @@ impl FromPyObject<'_> for PyArrayUnicode {
             //     elsize as isize / alignment as isize,
             // );
             // let py = ob.py();
-            // let obj = PyObject::from_owned_ptr(py, unicode);
+            // let obj = Py::from_owned_ptr(py, unicode);
             // let s = obj.downcast_bound::<PyString>(py)?;
             // Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
         })
@@ -351,7 +351,7 @@ struct PyArrayStr(Vec<String>);
 
 impl FromPyObject<'_> for PyArrayStr {
     fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
-        let array = ob.downcast::<PyArray1<PyObject>>()?;
+        let array = ob.downcast::<PyArray1<Py<PyAny>>>()?;
         let seq = array
             .readonly()
             .as_array()
@@ -595,7 +595,7 @@ impl PyTokenizer {
         PyTokenizer::from_model(model.clone())
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.tokenizer).map_err(|e| {
             exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Tokenizer: {e}"
@@ -604,7 +604,7 @@ impl PyTokenizer {
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 self.tokenizer = serde_json::from_slice(s).map_err(|e| {
@@ -619,7 +619,7 @@ impl PyTokenizer {
     }
 
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
-        let model: PyObject = PyModel::from(BPE::default())
+        let model: Py<PyAny> = PyModel::from(BPE::default())
             .into_pyobject(py)?
             .into_any()
             .into();
@@ -700,7 +700,7 @@ impl PyTokenizer {
         revision: String,
         token: Option<String>,
     ) -> PyResult<Self> {
-        let path = Python::with_gil(|py| -> PyResult<String> {
+        let path = Python::attach(|py| -> PyResult<String> {
             let huggingface_hub = PyModule::import(py, intern!(py, "huggingface_hub"))?;
             let hf_hub_download = huggingface_hub.getattr(intern!(py, "hf_hub_download"))?;
             let kwargs = [
@@ -1146,24 +1146,15 @@ impl PyTokenizer {
         let tokenizer = self.tokenizer.clone();
         let rt = crate::TOKIO_RUNTIME.clone();
 
-        let fut = py.allow_threads(|| async move {
-            let result = rt
-                .spawn_blocking(move || {
-                    tokenizer
-                        .encode(input, add_special_tokens)
-                        .map(PyEncoding::from)
-                })
-                .await
-                .unwrap();
-
-            // Convert to a Python object directly
-            match result {
-                Ok(encoding) => Python::with_gil(|py| {
-                    let obj: PyObject = encoding.into_pyobject(py)?.into_any().unbind();
-                    Ok(obj)
-                }),
-                Err(e) => Err(exceptions::PyException::new_err(e.to_string())),
-            }
+        let fut = py.detach(|| async move {
+            rt.spawn_blocking(move || {
+                tokenizer
+                    .encode(input, add_special_tokens)
+                    .map(PyEncoding::from)
+            })
+            .await
+            .unwrap()
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
         });
 
         pyo3_async_runtimes::tokio::future_into_py(py, fut)
@@ -1220,7 +1211,7 @@ impl PyTokenizer {
             };
             items.push(item);
         }
-        py.allow_threads(|| {
+        py.detach(|| {
             ToPyResult(
                 self.tokenizer
                     .encode_batch_char_offsets(items, add_special_tokens)
@@ -1276,27 +1267,15 @@ impl PyTokenizer {
         let tokenizer = self.tokenizer.clone();
         let rt = crate::TOKIO_RUNTIME.clone();
 
-        let fut = py.allow_threads(|| async move {
-            let result = rt
-                .spawn_blocking(move || {
-                    tokenizer
-                        .encode_batch_char_offsets(owned_items, add_special_tokens)
-                        .map(|encs| encs.into_iter().map(PyEncoding::from).collect::<Vec<_>>())
-                })
-                .await
-                .unwrap();
-
-            // Convert to a Python object directly rather than going through ToPyResult
-            match result {
-                Ok(encodings) => Python::with_gil(|py| {
-                    let obj: PyObject = encodings
-                        .into_pyobject(py)? // Vec<PyEncoding> -> Bound<'py, PyList>
-                        .into_any() // Bound<'py, PyAny>
-                        .unbind(); // Py<PyAny> a.k.a. PyObject (owned)
-                    Ok(obj)
-                }),
-                Err(e) => Err(exceptions::PyException::new_err(e.to_string())),
-            }
+        let fut = py.detach(|| async move {
+            rt.spawn_blocking(move || {
+                tokenizer
+                    .encode_batch_char_offsets(owned_items, add_special_tokens)
+                    .map(|encs| encs.into_iter().map(PyEncoding::from).collect::<Vec<_>>())
+            })
+            .await
+            .unwrap()
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
         });
 
         pyo3_async_runtimes::tokio::future_into_py(py, fut)
@@ -1351,7 +1330,7 @@ impl PyTokenizer {
             };
             items.push(item);
         }
-        py.allow_threads(|| {
+        py.detach(|| {
             ToPyResult(
                 self.tokenizer
                     .encode_batch_fast(items, add_special_tokens)
@@ -1406,7 +1385,7 @@ impl PyTokenizer {
         let tokenizer = self.tokenizer.clone();
         let rt = crate::TOKIO_RUNTIME.clone();
 
-        let fut = py.allow_threads(|| async move {
+        let fut = py.detach(|| async move {
             let result = rt
                 .spawn_blocking(move || {
                     tokenizer
@@ -1418,13 +1397,7 @@ impl PyTokenizer {
                 .unwrap();
             // Convert to a Python object directly rather than going through ToPyResult
             match result {
-                Ok(encodings) => Python::with_gil(|py| {
-                    let obj: PyObject = encodings
-                        .into_pyobject(py)? // Vec<PyEncoding> -> Bound<'py, PyList>
-                        .into_any() // Bound<'py, PyAny>
-                        .unbind(); // Py<PyAny> a.k.a. PyObject (owned)
-                    Ok(obj)
-                }),
+                Ok(encodings) => Python::attach(|py| encodings.into_py_any(py)),
                 Err(e) => Err(exceptions::PyException::new_err(e.to_string())),
             }
         });
@@ -1470,7 +1443,7 @@ impl PyTokenizer {
         sequences: Vec<Vec<u32>>,
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
-        py.allow_threads(|| {
+        py.detach(|| {
             let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<_>>();
             ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
@@ -1498,25 +1471,14 @@ impl PyTokenizer {
         let tokenizer = self.tokenizer.clone();
         let rt = crate::TOKIO_RUNTIME.clone();
 
-        let fut = py.allow_threads(|| async move {
-            let result = rt
-                .spawn_blocking(move || {
-                    let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<_>>();
-                    tokenizer.decode_batch(&slices, skip_special_tokens)
-                })
-                .await
-                .unwrap();
-
-            match result {
-                Ok(decoded_strings) => Python::with_gil(|py| {
-                    let obj: PyObject = decoded_strings
-                        .into_pyobject(py)? // Vec<String> -> Bound<'py, PyList>
-                        .into_any() // Bound<'py, PyAny>
-                        .unbind(); // Py<PyAny> a.k.a. PyObject (owned)
-                    Ok(obj)
-                }),
-                Err(e) => Err(exceptions::PyException::new_err(e.to_string())),
-            }
+        let fut = py.detach(|| async move {
+            rt.spawn_blocking(move || {
+                let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<_>>();
+                tokenizer.decode_batch(&slices, skip_special_tokens)
+            })
+            .await
+            .unwrap()
+            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
         });
 
         pyo3_async_runtimes::tokio::future_into_py(py, fut)
@@ -1653,8 +1615,8 @@ impl PyTokenizer {
     fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
         let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
-        Python::with_gil(|py| {
-            py.allow_threads(|| {
+        Python::attach(|py| {
+            py.detach(|| {
                 ToPyResult(
                     self.tokenizer
                         .train_from_files(&mut trainer, files)
@@ -1718,7 +1680,7 @@ impl PyTokenizer {
             256,
         )?;
 
-        py.allow_threads(|| {
+        py.detach(|| {
             ResultShunt::process(buffered_iter, |iter| {
                 self.tokenizer
                     .train(&mut trainer, MaybeSizedIterator::new(iter, length))
@@ -1772,7 +1734,7 @@ impl PyTokenizer {
 
     /// The :class:`~tokenizers.models.Model` in use by the Tokenizer
     #[getter]
-    fn get_model(&self, py: Python<'_>) -> PyResult<PyObject> {
+    fn get_model(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         self.tokenizer.get_model().get_as_subtype(py)
     }
 
@@ -1784,7 +1746,7 @@ impl PyTokenizer {
 
     /// The `optional` :class:`~tokenizers.normalizers.Normalizer` in use by the Tokenizer
     #[getter]
-    fn get_normalizer(&self, py: Python<'_>) -> PyResult<PyObject> {
+    fn get_normalizer(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         if let Some(n) = self.tokenizer.get_normalizer() {
             n.get_as_subtype(py)
         } else {
@@ -1801,7 +1763,7 @@ impl PyTokenizer {
 
     /// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer
     #[getter]
-    fn get_pre_tokenizer(&self, py: Python<'_>) -> PyResult<PyObject> {
+    fn get_pre_tokenizer(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         if let Some(pt) = self.tokenizer.get_pre_tokenizer() {
             pt.get_as_subtype(py)
         } else {
@@ -1818,7 +1780,7 @@ impl PyTokenizer {
 
     /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer
     #[getter]
-    fn get_post_processor(&self, py: Python<'_>) -> PyResult<PyObject> {
+    fn get_post_processor(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         if let Some(n) = self.tokenizer.get_post_processor() {
             n.get_as_subtype(py)
         } else {
@@ -1835,7 +1797,7 @@ impl PyTokenizer {
 
     /// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer
     #[getter]
-    fn get_decoder(&self, py: Python<'_>) -> PyResult<PyObject> {
+    fn get_decoder(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         if let Some(dec) = self.tokenizer.get_decoder() {
             dec.get_as_subtype(py)
         } else {
diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 30786862e..53415fff0 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -26,31 +26,25 @@ impl PyTrainer {
     pub(crate) fn new(trainer: Arc<RwLock<TrainerWrapper>>) -> Self {
         PyTrainer { trainer }
     }
-    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
+    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         let base = self.clone();
         Ok(match *self.trainer.as_ref().read().unwrap() {
-            TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
-            TrainerWrapper::WordPieceTrainer(_) => Py::new(py, (PyWordPieceTrainer {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
-            TrainerWrapper::WordLevelTrainer(_) => Py::new(py, (PyWordLevelTrainer {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
-            TrainerWrapper::UnigramTrainer(_) => Py::new(py, (PyUnigramTrainer {}, base))?
-                .into_pyobject(py)?
-                .into_any()
-                .into(),
+            TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_any(),
+            TrainerWrapper::WordPieceTrainer(_) => {
+                Py::new(py, (PyWordPieceTrainer {}, base))?.into_any()
+            }
+            TrainerWrapper::WordLevelTrainer(_) => {
+                Py::new(py, (PyWordLevelTrainer {}, base))?.into_any()
+            }
+            TrainerWrapper::UnigramTrainer(_) => {
+                Py::new(py, (PyUnigramTrainer {}, base))?.into_any()
+            }
         })
     }
 }
 #[pymethods]
 impl PyTrainer {
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
         let data = serde_json::to_string(&self.trainer).map_err(|e| {
             exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle PyTrainer: {e}"
@@ -59,7 +53,7 @@ impl PyTrainer {
         Ok(PyBytes::new(py, data.as_bytes()).into())
     }
 
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+    fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
         match state.extract::<&[u8]>(py) {
             Ok(s) => {
                 let unpickled = serde_json::from_slice(s).map_err(|e| {
@@ -909,7 +903,7 @@ mod tests {
 
     #[test]
     fn get_subtype() {
-        Python::with_gil(|py| {
+        Python::attach(|py| {
             let py_trainer = PyTrainer::new(Arc::new(RwLock::new(BpeTrainer::default().into())));
             let py_bpe = py_trainer.get_as_subtype(py).unwrap();
             assert_eq!("BpeTrainer", py_bpe.bind(py).get_type().qualname().unwrap());
diff --git a/bindings/python/src/utils/iterators.rs b/bindings/python/src/utils/iterators.rs
index d619b93d2..3d985ffb5 100644
--- a/bindings/python/src/utils/iterators.rs
+++ b/bindings/python/src/utils/iterators.rs
@@ -82,7 +82,7 @@ where
             return Ok(());
         }
 
-        Python::with_gil(|py| loop {
+        Python::attach(|py| loop {
             if self.buffer.len() >= self.size {
                 return Ok(());
             }
diff --git a/bindings/python/src/utils/normalization.rs b/bindings/python/src/utils/normalization.rs
index 191db7ed2..75cbc2681 100644
--- a/bindings/python/src/utils/normalization.rs
+++ b/bindings/python/src/utils/normalization.rs
@@ -28,9 +28,7 @@ impl Pattern for PyPattern {
                     s.find_matches(inside)
                 }
             }
-            PyPattern::Regex(r) => {
-                Python::with_gil(|py| (&r.borrow(py).inner).find_matches(inside))
-            }
+            PyPattern::Regex(r) => Python::attach(|py| (&r.borrow(py).inner).find_matches(inside)),
         }
     }
 }
@@ -39,7 +37,7 @@ impl From<PyPattern> for tk::normalizers::replace::ReplacePattern {
     fn from(pattern: PyPattern) -> Self {
         match pattern {
             PyPattern::Str(s) => Self::String(s.to_owned()),
-            PyPattern::Regex(r) => Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.clone())),
+            PyPattern::Regex(r) => Python::attach(|py| Self::Regex(r.borrow(py).pattern.clone())),
         }
     }
 }
@@ -48,7 +46,7 @@ impl From<PyPattern> for tk::pre_tokenizers::split::SplitPattern {
     fn from(pattern: PyPattern) -> Self {
         match pattern {
             PyPattern::Str(s) => Self::String(s.to_owned()),
-            PyPattern::Regex(r) => Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.clone())),
+            PyPattern::Regex(r) => Python::attach(|py| Self::Regex(r.borrow(py).pattern.clone())),
         }
     }
 }
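Editor's note (not part of the patch): the renames this diff applies throughout are pyo3 0.26's `Python::with_gil` -> `Python::attach`, `Python::allow_threads` -> `Python::detach`, the `PyObject` alias -> `Py<PyAny>`, and the old `.into_pyobject(py)?.into_any().unbind()` chain -> `into_py_any` from `pyo3::IntoPyObjectExt`. A minimal standalone sketch under those assumptions, with `lengths_as_py` a made-up helper rather than code from this repository:

// Hypothetical illustration of the pyo3 0.26 API used by this patch.
use pyo3::prelude::*;
use pyo3::IntoPyObjectExt;

fn lengths_as_py(tokens: Vec<String>) -> PyResult<Py<PyAny>> {
    // 0.26: Python::attach replaces Python::with_gil.
    Python::attach(|py| {
        // 0.26: py.detach replaces py.allow_threads, releasing the
        // interpreter while the pure-Rust work runs.
        let lens: Vec<usize> = py.detach(|| tokens.iter().map(|t| t.len()).collect());
        // One call instead of .into_pyobject(py)?.into_any().unbind().
        lens.into_py_any(py)
    })
}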