diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index a26cb26bbb..dcaabd87b8 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -1597,11 +1597,10 @@ def _object_getattribute_lookaside(obj: Any, name: str): # 2) If `obj` has a metaclass, the dunder methods might be dynamic. # So for now we just fall back to the builtin `getattr` for these bedrock lookups. if DUNDER_PATTERN.match(name) or isinstance(uobj, (type, super)): - return ( - do_raise(AttributeError(f"'{type(uobj).__name__}' object has no attribute '{name}'")) - if (result := getattr(uobj, name, null)) is null - else result - ) + result = getattr(uobj, name, null) + if result is null: + return do_raise(AttributeError(f"'{type(uobj).__name__}' object has no attribute '{name}'")) + return result def lookup_descriptor_field(field_name): # Bypass the C portions of `property` so we don't break the `_interpret_call` chain @@ -1776,14 +1775,14 @@ def _getattr_lookaside(obj: Any, name: str, *maybe_default: Any): # `__getattr__` is only triggered if `__getattribute__` fails. # TODO: this should be `_interpret_call_with_unwrapping(getattr, obj, "__getattr__", null := object())`, but that would require multiple current exceptions. null = object() - obj_getattr = getattr(unwrap(obj), "__getattr__", null) + obj_getattr = getattr(type(unwrap(obj)), "__getattr__", null) if obj_getattr is not null: ctx.curexc = None assert callable(obj_getattr) if compilectx._with_provenance_tracking: obj_getattr = wrap_attribute(obj_getattr, obj, wrap_const("__getattr__")) - result = _interpret_call(obj_getattr, name) + result = _interpret_call(obj_getattr, obj, name) # which provenances to cache here? # result = wrap_attribute(unwrap(result), obj, name) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index c003da8645..11f8e8e002 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -3577,3 +3577,16 @@ def fn(): jfn = thunder.jit(fn) out = jfn() assert out == ("list[int]", "dict[int, int]") + + +def test_getattr_type(jit): + m = torch.nn.Linear(5, 2) + + def fn(): + return bool(m), getattr(m, "weight", None) + + expected = fn() + actual = jit(fn)() + + assert actual[0] == expected[0] + assert actual[1] is expected[1]