Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,11 +1597,10 @@ def _object_getattribute_lookaside(obj: Any, name: str):
# 2) If `obj` has a metaclass, the dunder methods might be dynamic.
# So for now we just fall back to the builtin `getattr` for these bedrock lookups.
if DUNDER_PATTERN.match(name) or isinstance(uobj, (type, super)):
return (
do_raise(AttributeError(f"'{type(uobj).__name__}' object has no attribute '{name}'"))
if (result := getattr(uobj, name, null)) is null
else result
)
result = getattr(uobj, name, null)
if result is null:
return do_raise(AttributeError(f"'{type(uobj).__name__}' object has no attribute '{name}'"))
return result

def lookup_descriptor_field(field_name):
# Bypass the C portions of `property` so we don't break the `_interpret_call` chain
Expand Down Expand Up @@ -1776,14 +1775,14 @@ def _getattr_lookaside(obj: Any, name: str, *maybe_default: Any):
# `__getattr__` is only triggered if `__getattribute__` fails.
# TODO: this should be `_interpret_call_with_unwrapping(getattr, obj, "__getattr__", null := object())`, but that would require multiple current exceptions.
null = object()
obj_getattr = getattr(unwrap(obj), "__getattr__", null)
obj_getattr = getattr(type(unwrap(obj)), "__getattr__", null)

if obj_getattr is not null:
ctx.curexc = None
assert callable(obj_getattr)
if compilectx._with_provenance_tracking:
obj_getattr = wrap_attribute(obj_getattr, obj, wrap_const("__getattr__"))
result = _interpret_call(obj_getattr, name)
result = _interpret_call(obj_getattr, obj, name)
# which provenances to cache here?
# result = wrap_attribute(unwrap(result), obj, name)

Expand Down
13 changes: 13 additions & 0 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3577,3 +3577,16 @@ def fn():
jfn = thunder.jit(fn)
out = jfn()
assert out == ("list[int]", "dict[int, int]")


def test_getattr_type(jit):
m = torch.nn.Linear(5, 2)

def fn():
return bool(m), getattr(m, "weight", None)

expected = fn()
actual = jit(fn)()

assert actual[0] == expected[0]
assert actual[1] is expected[1]
Loading