Skip to content

Commit 620ebd0

Browse files
pytorchbotmlazos
andauthored
[Dynamo] Use proper sources for constructing dataclass defaults (pytorch#158689)
[Dynamo] Use proper sources for constructing dataclass defaults (pytorch#157993) Partially fixes pytorch#154009 Pull Request resolved: pytorch#157993 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 (cherry picked from commit 89850bb) Co-authored-by: Michael Lazos <[email protected]>
1 parent 5d52613 commit 620ebd0

File tree

5 files changed

+62
-2
lines changed

5 files changed

+62
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10411,6 +10411,26 @@ def fn(x, y):
1041110411
actual = fn_opt(*inps)
1041210412
expected = fn(*inps)
1041310413

10414+
def test_nested_dataclass_reconstruct(self):
10415+
@dataclasses.dataclass(frozen=True)
10416+
class NestedDataClass:
10417+
x: int = 2
10418+
10419+
@dataclasses.dataclass(frozen=True)
10420+
class TestDataClass:
10421+
y: torch.Tensor
10422+
ndc: NestedDataClass = NestedDataClass()
10423+
10424+
def fn(y):
10425+
dc = TestDataClass(y)
10426+
z = dc.y + dc.ndc.x
10427+
return z, dc
10428+
10429+
fn_opt = torch.compile()(fn)
10430+
inps = (torch.ones(2, 2),)
10431+
actual = fn_opt(*inps)
10432+
expected = fn(*inps)
10433+
1041410434
def test_frozen_dataclass_default_value(self):
1041510435
@dataclasses.dataclass(frozen=True)
1041610436
class TestDataClass:

torch/_dynamo/guards.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
ChainedSource,
105105
ConstantSource,
106106
ConstDictKeySource,
107+
DataclassFieldsSource,
107108
DefaultsSource,
108109
DictGetItemSource,
109110
DictSubclassGetItemSource,
@@ -144,6 +145,7 @@
144145
from .utils import (
145146
builtin_dict_keys,
146147
common_constant_types,
148+
dataclass_fields,
147149
dict_keys,
148150
get_custom_getattr,
149151
get_torch_function_mode_stack,
@@ -449,6 +451,7 @@ def _get_closure_vars():
449451
"___tuple_iterator_len": tuple_iterator_len,
450452
"___normalize_range_iter": normalize_range_iter,
451453
"___tuple_iterator_getitem": tuple_iterator_getitem,
454+
"___dataclass_fields": dataclass_fields,
452455
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
453456
"__math_isnan": math.isnan,
454457
"__numpy_isnan": None if np is None else np.isnan,
@@ -1301,6 +1304,14 @@ def get_guard_manager_from_source(self, source):
13011304
example_value=example_value,
13021305
guard_manager_enum=guard_manager_enum,
13031306
)
1307+
elif istype(source, DataclassFieldsSource):
1308+
assert base_guard_manager
1309+
out = base_guard_manager.lambda_manager(
1310+
python_lambda=lambda x: dataclass_fields(x),
1311+
source=source_name,
1312+
example_value=example_value,
1313+
guard_manager_enum=guard_manager_enum,
1314+
)
13041315
else:
13051316
raise AssertionError(
13061317
f"missing guard manager builder {source} - {source.name()}"

torch/_dynamo/source.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,22 @@ def name(self):
723723
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
724724

725725

726+
@dataclasses.dataclass(frozen=True)
727+
class DataclassFieldsSource(ChainedSource):
728+
def reconstruct(self, codegen: "PyCodegen"):
729+
codegen.add_push_null(
730+
lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
731+
)
732+
codegen(self.base)
733+
codegen.extend_output(create_call_function(1, False))
734+
735+
def guard_source(self):
736+
return self.base.guard_source()
737+
738+
def name(self):
739+
return f"___dataclass_fields({self.base.name()})"
740+
741+
726742
@dataclasses.dataclass(frozen=True)
727743
class TypeSource(ChainedSource):
728744
def __post_init__(self):

torch/_dynamo/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2544,6 +2544,10 @@ def tuple_iterator_getitem(it, index):
25442544
return obj[start + index]
25452545

25462546

2547+
def dataclass_fields(cls):
2548+
return torch._dynamo.disable(dataclasses.fields)(cls)
2549+
2550+
25472551
iter_next = next
25482552

25492553

torch/_dynamo/variables/user_defined.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import enum
3030
import functools
3131
import inspect
32+
import itertools
3233
import random
3334
import sys
3435
import threading
@@ -56,6 +57,7 @@
5657
from ..source import (
5758
AttrSource,
5859
CallFunctionNoArgsSource,
60+
DataclassFieldsSource,
5961
GetItemSource,
6062
RandomValueSource,
6163
TypeSource,
@@ -610,11 +612,12 @@ def call_function(
610612
return SizeVariable(tup.items)
611613
elif is_frozen_dataclass(self.value) and self.is_standard_new():
612614
fields = dataclasses.fields(self.value)
615+
fields_source = DataclassFieldsSource(self.source)
613616
items = list(args)
614617
items.extend([None] * (len(fields) - len(items)))
615618

616619
default_kwargs = {}
617-
for field, var_tracker in zip(fields, items):
620+
for ind, field, var_tracker in zip(itertools.count(), fields, items):
618621
if var_tracker is None:
619622
if field.name in kwargs:
620623
var_tracker = kwargs[field.name]
@@ -623,7 +626,13 @@ def call_function(
623626
continue
624627

625628
if field.default is not dataclasses.MISSING:
626-
var_tracker = VariableTracker.build(tx, field.default)
629+
var_tracker = VariableTracker.build(
630+
tx,
631+
field.default,
632+
source=AttrSource(
633+
GetItemSource(fields_source, ind), "default"
634+
),
635+
)
627636
elif field.default_factory is not dataclasses.MISSING:
628637
factory_fn = VariableTracker.build(
629638
tx, field.default_factory

0 commit comments

Comments
 (0)