Skip to content

Commit 211329e

Browse files
committed
post3: Deferred torch import (duck-typing via module name)
- Remove PYTORCH_TENSOR_TYPE static and look_up_pytorch_type() - Replace type-based pytorch detection with duck-typing - Check __module__ starts with 'torch' then verify numpy/cpu/detach methods - Export PyObject_HasAttrString and PyObject_GetAttrString from ffi - No longer imports torch at startup, reducing overhead for non-torch processes
1 parent c735d17 commit 211329e

File tree

3 files changed

+35
-26
lines changed

3 files changed

+35
-26
lines changed

src/ffi/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ pub(crate) use pyo3_ffi::{
7474
PyMemoryView_Type, PyMethodDef, PyMethodDefPointer, PyModule_AddIntConstant,
7575
PyModule_AddObject, PyModuleDef, PyModuleDef_HEAD_INIT, PyModuleDef_Init, PyModuleDef_Slot,
7676
PyObject, PyObject_CallFunctionObjArgs, PyObject_CallMethodObjArgs, PyObject_GenericGetDict,
77-
PyObject_GetAttr, PyObject_HasAttr, PyObject_Hash, PyObject_Vectorcall, PyTuple_New,
77+
PyObject_GetAttr, PyObject_GetAttrString, PyObject_HasAttr, PyObject_HasAttrString,
78+
PyObject_Hash, PyObject_Vectorcall, PyTuple_New,
7879
PyTuple_Type, PyTupleObject, PyType_Ready, PyType_Type, PyTypeObject, PyUnicode_AsUTF8AndSize,
7980
PyUnicode_FromStringAndSize, PyUnicode_InternFromString, PyUnicode_New, PyUnicode_Type,
8081
PyVarObject, PyVectorcall_NARGS,

src/serialize/obtype.rs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
22
// Copyright ijl (2020-2025), Aviram Hassan (2020)
33

4+
use crate::ffi::{PyObject_HasAttrString, PyStrRef, PyTypeObject};
45
use crate::opt::{
56
Opt, PASSTHROUGH_DATACLASS, PASSTHROUGH_DATETIME, PASSTHROUGH_SUBCLASS, SERIALIZE_NUMPY,
67
};
78
use crate::serialize::per_type::{is_numpy_array, is_numpy_scalar};
89
use crate::typeref::{
910
BOOL_TYPE, DATACLASS_FIELDS_STR, DATE_TYPE, DATETIME_TYPE, DICT_TYPE, ENUM_TYPE, FLOAT_TYPE,
10-
FRAGMENT_TYPE, INT_TYPE, LIST_TYPE, NONE_TYPE, PYTORCH_TENSOR_TYPE, STR_TYPE, TIME_TYPE,
11-
TUPLE_TYPE, UUID_TYPE,
11+
FRAGMENT_TYPE, INT_TYPE, LIST_TYPE, NONE_TYPE, STR_TYPE, TIME_TYPE, TUPLE_TYPE, UUID_TYPE,
1212
};
1313

1414
#[repr(u32)]
@@ -110,10 +110,40 @@ pub(crate) fn pyobject_to_obtype_unlikely(
110110
return ObType::NumpyScalar;
111111
} else if is_numpy_array(ob_type) {
112112
return ObType::NumpyArray;
113-
} else if is_class_by_type!(ob_type, PYTORCH_TENSOR_TYPE) {
113+
} else if is_pytorch_tensor(ob_type) {
114114
return ObType::PyTorchTensor;
115115
}
116116
}
117117

118118
ObType::Unknown
119119
}
120+
121+
#[cold]
122+
fn is_pytorch_tensor(ob_type: *mut PyTypeObject) -> bool {
123+
unsafe {
124+
// Check if the type's __module__ starts with "torch" first,
125+
// to avoid calling HasAttr on types like MagicMock
126+
let ob_type_ptr = ob_type.cast::<crate::ffi::PyObject>();
127+
let module = crate::ffi::PyObject_GetAttrString(ob_type_ptr, c"__module__".as_ptr());
128+
if module.is_null() {
129+
crate::ffi::PyErr_Clear();
130+
return false;
131+
}
132+
let starts_with_torch = match PyStrRef::from_ptr(module) {
133+
Ok(s) => match s.as_str() {
134+
Some(s) => s.starts_with("torch"),
135+
None => false,
136+
},
137+
Err(_) => false,
138+
};
139+
ffi!(Py_DECREF(module));
140+
if !starts_with_torch {
141+
return false;
142+
}
143+
144+
// Verify it has the expected tensor methods
145+
PyObject_HasAttrString(ob_type_ptr, c"numpy".as_ptr()) == 1
146+
&& PyObject_HasAttrString(ob_type_ptr, c"cpu".as_ptr()) == 1
147+
&& PyObject_HasAttrString(ob_type_ptr, c"detach".as_ptr()) == 1
148+
}
149+
}

src/typeref.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ pub(crate) static mut FIELD_TYPE: *mut PyTypeObject = null_mut();
4040
pub(crate) static mut FRAGMENT_TYPE: *mut PyTypeObject = null_mut();
4141

4242
pub(crate) static mut ZONEINFO_TYPE: *mut PyTypeObject = null_mut();
43-
pub(crate) static mut PYTORCH_TENSOR_TYPE: *mut PyTypeObject = null_mut();
4443

4544
pub(crate) static mut UTCOFFSET_METHOD_STR: *mut PyObject = null_mut();
4645
pub(crate) static mut NORMALIZE_METHOD_STR: *mut PyObject = null_mut();
@@ -175,26 +174,6 @@ fn _init_typerefs_impl() -> bool {
175174
true
176175
}
177176

178-
#[cold]
179-
#[cfg_attr(feature = "optimize", optimize(size))]
180-
pub(crate) fn look_up_pytorch_type() {
181-
unsafe {
182-
let torch = PyImport_ImportModule(c"torch".as_ptr());
183-
if torch.is_null() {
184-
PyErr_Clear();
185-
return;
186-
}
187-
let torch_module_dict = PyObject_GenericGetDict(torch, null_mut());
188-
let tensor_type =
189-
PyMapping_GetItemString(torch_module_dict, c"Tensor".as_ptr()).cast::<PyTypeObject>();
190-
Py_XDECREF(torch_module_dict);
191-
Py_XDECREF(torch);
192-
if !tensor_type.is_null() {
193-
PYTORCH_TENSOR_TYPE = tensor_type;
194-
}
195-
}
196-
}
197-
198177
pub(crate) struct NumpyTypes {
199178
pub array: *mut PyTypeObject,
200179
pub float64: *mut PyTypeObject,
@@ -253,7 +232,6 @@ pub(crate) fn load_numpy_types() -> Box<Option<NonNull<NumpyTypes>>> {
253232
});
254233
Py_XDECREF(numpy_module_dict);
255234
Py_XDECREF(numpy);
256-
look_up_pytorch_type();
257235
Box::new(Some(nonnull!(Box::<NumpyTypes>::into_raw(types))))
258236
}
259237
}

0 commit comments

Comments
 (0)