Skip to content

Commit ac24932

Browse files
authored
Use underlying class's __new__ in MutableMappingWrapper.__new__ (#2514)
1 parent 316baa8 commit ac24932

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

thunder/core/interpreter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2513,7 +2513,8 @@ def __iter__(self):
25132513

25142514
class MutMappingWrapperMethods(WrappedValue):
25152515
def __new__(cls, /, *args, **kwds):
2516-
uvalue = unwrap(cls)()
2516+
ucls = unwrap(cls)
2517+
uvalue = ucls.__new__(ucls)
25172518
# todo: for subclasses, better record the call to the constructor
25182519
return wrap(uvalue, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[cls.provenance]))
25192520

thunder/tests/test_jit_general.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,3 +1730,17 @@ def _allocate_and_call_model_in_function():
17301730
ref = _allocate_and_call_model_in_function()
17311731
assert ref() is None
17321732
assert torch.cuda.memory_allocated() == memory_start
1733+
1734+
1735+
def test_dataclass_dict():
1736+
# diffusers model outputs are like this
1737+
from dataclasses import dataclass
1738+
1739+
@dataclass
1740+
class Foo(dict):
1741+
musthave: int
1742+
1743+
def fn():
1744+
return Foo(musthave=1)
1745+
1746+
assert fn() == thunder.jit(fn)()

0 commit comments

Comments
 (0)