Skip to content

Commit baf3cf8

Browse files
committed
post4: MagicMock handling fix
- Clarify that __module__ check must come before PyObject_HasAttrString - MagicMock objects auto-respond True to HasAttr, causing false positives - Module name check gates the HasAttr calls to prevent this
1 parent 211329e commit baf3cf8

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/serialize/obtype.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ pub(crate) fn pyobject_to_obtype_unlikely(
121121
#[cold]
122122
fn is_pytorch_tensor(ob_type: *mut PyTypeObject) -> bool {
123123
unsafe {
124-
// Check if the type's __module__ starts with "torch" first,
125-
// to avoid calling HasAttr on types like MagicMock
124+
// Check __module__ starts with "torch" BEFORE calling HasAttr.
125+
// This is important because MagicMock objects auto-respond True to
126+
// HasAttr calls, which would cause false positives.
126127
let ob_type_ptr = ob_type.cast::<crate::ffi::PyObject>();
127128
let module = crate::ffi::PyObject_GetAttrString(ob_type_ptr, c"__module__".as_ptr());
128129
if module.is_null() {
@@ -141,7 +142,7 @@ fn is_pytorch_tensor(ob_type: *mut PyTypeObject) -> bool {
141142
return false;
142143
}
143144

144-
// Verify it has the expected tensor methods
145+
// Only after confirming torch module, verify tensor methods
145146
PyObject_HasAttrString(ob_type_ptr, c"numpy".as_ptr()) == 1
146147
&& PyObject_HasAttrString(ob_type_ptr, c"cpu".as_ptr()) == 1
147148
&& PyObject_HasAttrString(ob_type_ptr, c"detach".as_ptr()) == 1

0 commit comments

Comments
 (0)