Skip to content

Commit c055ebe

Browse files
morrison-turnanskypytorchmergebot
authored andcommitted
Change NamedTupleVariable implementation to subclass UserDefinedTupleVariable (pytorch#167468)
Continuation of work from previous PR, see link for context pytorch#161645 (comment) I think this PR is a step in that direction. There is probably some room for simplification. At a high level, the new class NamedTupleVariable handles methods that branch on structseq or the more dynamic subclasses of namedtuple, and falls back to UserDefinedTupleVariable otherwise. Please let me know what you think. @StrongerXi Pull Request resolved: pytorch#167468 Approved by: https://github.com/guilhermeleobas, https://github.com/StrongerXi, https://github.com/mlazos
1 parent bc8da63 commit c055ebe

File tree

4 files changed

+85
-149
lines changed

4 files changed

+85
-149
lines changed

test/dynamo/test_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2053,7 +2053,7 @@ def test_namedtuple_defaults(a, b):
20532053
return mytuple(tmp.x, tmp[1], tmp.xy + b)
20542054

20552055
@make_test
2056-
def test_namedtuple_replace(a, b):
2056+
def test_namedtuple_replace_1(a, b):
20572057
mytuple = collections.namedtuple("mytuple", ["x", "y"])
20582058
t = mytuple(a, b)
20592059
t._replace(x=b)
@@ -2109,7 +2109,7 @@ def test_namedtuple_user_methods(a, b):
21092109
return mytuple.add(), mytuple.static_method(), mytuple.class_method()
21102110

21112111
@make_test
2112-
def test_namedtuple_replace(a, b):
2112+
def test_namedtuple_replace_2(a, b):
21132113
mytuple = FunctionTests.MyNamedTuple(a, b)
21142114
replaced = mytuple._replace(first=b)
21152115
return mytuple.first + mytuple.second + replaced.first + replaced.second

torch/_dynamo/variables/dicts.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
)
4545
from .base import ValueMutationNew, VariableTracker
4646
from .constant import ConstantVariable
47-
from .lists import ListIteratorVariable
4847

4948

5049
if TYPE_CHECKING:
@@ -792,6 +791,8 @@ def call_method(
792791
self.call_method(tx, "update", args, kwargs)
793792
return self
794793
elif name == "__iter__":
794+
from .lists import ListIteratorVariable
795+
795796
if self.source and not is_constant_source(self.source):
796797
tx.output.guard_on_key_order.add(self.source)
797798
return ListIteratorVariable(
@@ -1462,6 +1463,8 @@ def call_method(
14621463
if name == "__len__":
14631464
return self.dv_dict.call_method(tx, name, args, kwargs)
14641465
elif name == "__iter__":
1466+
from .lists import ListIteratorVariable
1467+
14651468
return ListIteratorVariable(
14661469
self.view_items_vt, mutation_type=ValueMutationNew()
14671470
)

torch/_dynamo/variables/lists.py

Lines changed: 61 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ class that handles its unique behaviors while integrating with Dynamo's
1515
"""
1616

1717
import collections
18-
import inspect
1918
import operator
2019
import sys
2120
from collections.abc import Sequence
@@ -39,7 +38,6 @@ class that handles its unique behaviors while integrating with Dynamo's
3938
get_fake_value,
4039
guard_if_dyn,
4140
iter_contains,
42-
Lit,
4341
namedtuple_fields,
4442
odict_values,
4543
raise_args_mismatch,
@@ -48,8 +46,8 @@ class that handles its unique behaviors while integrating with Dynamo's
4846
)
4947
from .base import ValueMutationNew, VariableTracker
5048
from .constant import ConstantVariable
51-
from .functions import UserFunctionVariable, UserMethodVariable
5249
from .iter import IteratorVariable
50+
from .user_defined import UserDefinedTupleVariable
5351

5452

5553
if TYPE_CHECKING:
@@ -1296,24 +1294,51 @@ def call_obj_hasattr(
12961294
return variables.ConstantVariable.create(hasattr(torch.Size, name))
12971295

12981296

1299-
class NamedTupleVariable(TupleVariable):
1297+
class NamedTupleVariable(UserDefinedTupleVariable):
13001298
_nonvar_fields = {
13011299
"tuple_cls",
13021300
"dynamic_attributes",
1303-
*TupleVariable._nonvar_fields,
1301+
*UserDefinedTupleVariable._nonvar_fields,
13041302
}
13051303

13061304
def __init__(
13071305
self,
13081306
items: list[VariableTracker],
1309-
tuple_cls: type,
1307+
tuple_cls: type[tuple],
13101308
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
13111309
**kwargs: Any,
13121310
) -> None:
1313-
super().__init__(items, **kwargs)
1311+
tuple_vt = variables.TupleVariable(
1312+
items, mutation_type=kwargs.get("mutation_type", ValueMutationNew())
1313+
)
1314+
1315+
# Create a dummy instance for method resolution
1316+
# This allows _maybe_get_baseclass_method to work correctly
1317+
fields = namedtuple_fields(tuple_cls)
1318+
num_fields = len(fields)
1319+
if tuple_cls.__module__ == "torch.return_types":
1320+
# Structseq: single iterable argument
1321+
dummy_value = tuple_cls([None] * num_fields)
1322+
else:
1323+
# Namedtuple: positional arguments
1324+
dummy_value = tuple_cls(*([None] * num_fields)) # type: ignore[arg-type]
1325+
1326+
super().__init__(
1327+
value=dummy_value,
1328+
tuple_vt=tuple_vt,
1329+
init_args=None,
1330+
**kwargs,
1331+
)
1332+
13141333
self.tuple_cls = tuple_cls
1334+
if len(self.tuple_cls.__mro__) < 3:
1335+
raise ValueError("NamedTuple should inherit from Tuple and Object.")
13151336
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
13161337

1338+
@property
1339+
def items(self) -> list[VariableTracker]:
1340+
return self._tuple_vt.items
1341+
13171342
def is_namedtuple(self) -> bool:
13181343
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
13191344
getattr(self.tuple_cls, "_make", None)
@@ -1325,17 +1350,7 @@ def is_structseq(self) -> bool:
13251350
def fields(self) -> tuple[str, ...]:
13261351
return namedtuple_fields(self.tuple_cls)
13271352

1328-
def debug_repr(self) -> str:
1329-
if self.is_structseq():
1330-
# StructSequenceType(iterable)
1331-
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
1332-
# NamedTupleType(*iterable)
1333-
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
1334-
1335-
def python_type(self) -> type:
1336-
return self.tuple_cls
1337-
1338-
def as_python_constant(self) -> Any:
1353+
def as_python_constant(self):
13391354
if self.is_structseq():
13401355
# StructSequenceType(iterable)
13411356
result = self.python_type()([x.as_python_constant() for x in self.items])
@@ -1357,57 +1372,39 @@ def as_python_constant(self) -> Any:
13571372

13581373
return result
13591374

1360-
def as_proxy(self) -> Any:
1361-
assert self.python_type() is not SizeVariable
1375+
def as_proxy(self):
13621376
if self.is_structseq():
1363-
# StructSequenceType(iterable)
1364-
return self.python_type()(self._as_proxy())
1365-
# NamedTupleType(*iterable)
1366-
return self.python_type()(*self._as_proxy())
1377+
return self.python_type()([x.as_proxy() for x in self._tuple_vt.items])
1378+
return self.python_type()(*[x.as_proxy() for x in self._tuple_vt.items])
13671379

13681380
def reconstruct(self, codegen: "PyCodegen") -> None:
1369-
# Always reconstruct the NamedTuple normally first
1370-
# Constructors:
1371-
# StructSequenceType(iterable)
1372-
# NamedTupleType(*iterable)
1373-
# NamedTupleType._make(iterable)
13741381
if self.is_structseq():
13751382
create_fn = self.tuple_cls
13761383
else:
13771384
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
1385+
13781386
codegen.add_push_null(
13791387
lambda: codegen.append_output(
13801388
codegen.create_load_const_unchecked(create_fn)
13811389
)
13821390
)
1383-
codegen.foreach(self.items)
1391+
codegen.foreach(self._tuple_vt.items)
13841392
codegen.extend_output(
13851393
[
1386-
create_build_tuple(len(self.items)),
1394+
create_build_tuple(len(self._tuple_vt.items)),
13871395
]
13881396
+ create_call_function(1, False)
13891397
)
13901398

1399+
# Apply initial dynamic attributes after construction (if any)
1400+
# Runtime dynamic attributes are tracked via side effects system
13911401
for name, value in self.dynamic_attributes.items():
13921402
codegen.dup_top()
13931403
codegen(value)
13941404
codegen.extend_output(create_rot_n(2))
13951405
codegen.store_attr(name)
13961406

13971407
def _is_method_overridden(self, method_name: str) -> bool:
1398-
"""Checks if a method is overridden in the NamedTuple subclass.
1399-
1400-
Args:
1401-
method_name (str): The name of the method to check.
1402-
1403-
Returns:
1404-
bool: True if the method is overridden in the subclass, False otherwise.
1405-
1406-
Raises:
1407-
ValueError: If the NamedTuple class does not inherit from both Tuple and Object.
1408-
"""
1409-
if len(self.tuple_cls.__mro__) < 3:
1410-
raise ValueError("NamedTuple should inherit from Tuple and Object.")
14111408
if getattr(self.tuple_cls, method_name, None) == getattr(
14121409
self.tuple_cls.__mro__[-3], method_name, None
14131410
):
@@ -1421,129 +1418,53 @@ def call_method(
14211418
args: list[VariableTracker],
14221419
kwargs: dict[str, VariableTracker],
14231420
) -> VariableTracker:
1424-
if name == "__setattr__":
1421+
if self._is_method_overridden(name):
1422+
# Fall back to UserDefinedTupleVariable
1423+
return super().call_method(tx, name, args, kwargs)
1424+
elif name == "__setattr__":
14251425
if kwargs or len(args) != 2:
14261426
raise_args_mismatch(
14271427
tx,
14281428
name,
14291429
"2 args and 0 kwargs",
14301430
f"{len(args)} args and {len(kwargs)} kwargs",
14311431
)
1432-
attr, value = args
1433-
attr = attr.as_python_constant()
1432+
attr_var, value = args
1433+
attr = attr_var.as_python_constant()
1434+
14341435
if (
14351436
# structseq is immutable
14361437
self.is_structseq()
14371438
# namedtuple directly created by `collections.namedtuple` is immutable
14381439
or self.tuple_cls.__bases__ == (tuple,)
1439-
# fields are immutable
14401440
or attr in self.fields()
14411441
):
14421442
raise_observed_exception(AttributeError, tx)
1443-
# Subclass of namedtuple type can have dynamic attributes
1444-
tx.output.side_effects.mutation(self)
1445-
if self.source:
1446-
tx.output.side_effects.store_attr(self, attr, value)
1447-
self.dynamic_attributes[attr] = value
1448-
return ConstantVariable.create(None)
1449-
elif name == "_replace":
1450-
# NamedTuple._replace should create a new instance with replaced fields
1451-
if args:
1452-
raise_args_mismatch(tx, name, "0 args", f"{len(args)} args")
1453-
1454-
# Get the field names for validation
1455-
fields = self.fields()
1456-
1457-
# Start with current items (copy them)
1458-
new_items = list(self.items)
1459-
1460-
# Replace fields specified in kwargs
1461-
for field_name, new_value in kwargs.items():
1462-
if field_name not in fields:
1463-
raise_observed_exception(
1464-
ValueError,
1465-
tx,
1466-
args=[
1467-
ConstantVariable.create(
1468-
f"Got unexpected field name: '{field_name}'"
1469-
)
1470-
],
1471-
)
1472-
1473-
# Replace the item at the field's index
1474-
field_index = fields.index(field_name)
1475-
new_items[field_index] = new_value
14761443

1477-
return NamedTupleVariable(new_items, self.tuple_cls)
1444+
result = self.method_setattr_standard(tx, attr_var, value)
1445+
# Also update self.dynamic_attributes
1446+
self.dynamic_attributes[attr] = value
1447+
return result
14781448

14791449
return super().call_method(tx, name, args, kwargs)
14801450

1481-
def getitem_const(
1482-
self, tx: "InstructionTranslator", arg: VariableTracker
1483-
) -> VariableTracker:
1484-
if isinstance(arg, SliceVariable):
1485-
# slicing a namedtuple produces a tuple
1486-
return TupleVariable(
1487-
self.items[arg.as_python_constant()],
1488-
source=None,
1489-
)
1490-
return super().getitem_const(tx, arg)
1491-
1492-
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
1493-
def check_and_create_method() -> Optional[VariableTracker]:
1494-
method = inspect.getattr_static(self.tuple_cls, name, None)
1495-
if isinstance(method, classmethod):
1496-
# We need the unbounded cls method to avoid the inline __self__
1497-
return UserMethodVariable(
1498-
method.__func__,
1499-
variables.UserDefinedClassVariable(self.tuple_cls),
1500-
)
1501-
elif isinstance(method, staticmethod):
1502-
# pyrefly: ignore[bad-argument-type]
1503-
return UserFunctionVariable(method.__func__)
1504-
elif inspect.isfunction(method):
1505-
return UserMethodVariable(method, self)
1506-
else:
1507-
return None
1508-
1509-
# Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten.
1510-
if (
1511-
name == "_replace"
1512-
and not self._is_method_overridden("_replace")
1513-
and not self._is_method_overridden("__getattr__")
1514-
):
1515-
# Return a BuiltinVariable for the _replace method
1516-
# Get the actual _replace method from the tuple class
1517-
actual_replace_method = getattr(self.tuple_cls, "_replace", None)
1518-
if actual_replace_method:
1519-
from ..source import AttrSource
1520-
1521-
source = AttrSource(self.source, name) if self.source else None
1522-
return variables.GetAttrVariable(self, name, source=source)
1523-
# Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples)
1524-
return super().var_getattr(tx, name)
1451+
def python_type(self) -> type:
1452+
return self.tuple_cls
15251453

1454+
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
15261455
if name == "_fields":
1527-
result_source = NamedTupleFieldsSource(self.source) if self.source else None
1528-
return VariableTracker.build(tx, self.fields(), source=result_source)
1456+
source = NamedTupleFieldsSource(self.source) if self.source else None
1457+
return VariableTracker.build(tx, self.fields(), source=source)
15291458

15301459
if name in self.dynamic_attributes:
15311460
return self.dynamic_attributes[name]
15321461

15331462
fields = self.fields()
1534-
if name not in fields:
1535-
method = check_and_create_method()
1536-
if not method:
1537-
return super().var_getattr(tx, name)
1538-
return method
1539-
return self.items[fields.index(name)]
1463+
if name in fields:
1464+
field_index = fields.index(name)
1465+
return self._tuple_vt.items[field_index]
15401466

1541-
def call_obj_hasattr(
1542-
self, tx: "InstructionTranslator", name: str
1543-
) -> VariableTracker:
1544-
return variables.ConstantVariable.create(
1545-
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
1546-
)
1467+
return super().var_getattr(tx, name)
15471468

15481469

15491470
class SliceVariable(VariableTracker):

0 commit comments

Comments
 (0)