Skip to content

Commit 7a064ed

Browse files
Revert "Change NamedTupleVariable implementation to subclass UserDefinedTupleVariable (pytorch#167468)"
This reverts commit c055ebe. Reverted pytorch#167468 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](pytorch#167468 (comment)))
1 parent c3320ed commit 7a064ed

File tree

4 files changed

+149
-85
lines changed

4 files changed

+149
-85
lines changed

test/dynamo/test_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2053,7 +2053,7 @@ def test_namedtuple_defaults(a, b):
20532053
return mytuple(tmp.x, tmp[1], tmp.xy + b)
20542054

20552055
@make_test
2056-
def test_namedtuple_replace_1(a, b):
2056+
def test_namedtuple_replace(a, b):
20572057
mytuple = collections.namedtuple("mytuple", ["x", "y"])
20582058
t = mytuple(a, b)
20592059
t._replace(x=b)
@@ -2109,7 +2109,7 @@ def test_namedtuple_user_methods(a, b):
21092109
return mytuple.add(), mytuple.static_method(), mytuple.class_method()
21102110

21112111
@make_test
2112-
def test_namedtuple_replace_2(a, b):
2112+
def test_namedtuple_replace(a, b):
21132113
mytuple = FunctionTests.MyNamedTuple(a, b)
21142114
replaced = mytuple._replace(first=b)
21152115
return mytuple.first + mytuple.second + replaced.first + replaced.second

torch/_dynamo/variables/dicts.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from .base import ValueMutationNew, VariableTracker
4646
from .constant import ConstantVariable
47+
from .lists import ListIteratorVariable
4748

4849

4950
if TYPE_CHECKING:
@@ -791,8 +792,6 @@ def call_method(
791792
self.call_method(tx, "update", args, kwargs)
792793
return self
793794
elif name == "__iter__":
794-
from .lists import ListIteratorVariable
795-
796795
if self.source and not is_constant_source(self.source):
797796
tx.output.guard_on_key_order.add(self.source)
798797
return ListIteratorVariable(
@@ -1463,8 +1462,6 @@ def call_method(
14631462
if name == "__len__":
14641463
return self.dv_dict.call_method(tx, name, args, kwargs)
14651464
elif name == "__iter__":
1466-
from .lists import ListIteratorVariable
1467-
14681465
return ListIteratorVariable(
14691466
self.view_items_vt, mutation_type=ValueMutationNew()
14701467
)

torch/_dynamo/variables/lists.py

Lines changed: 140 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class that handles its unique behaviors while integrating with Dynamo's
1515
"""
1616

1717
import collections
18+
import inspect
1819
import operator
1920
import sys
2021
from collections.abc import Sequence
@@ -38,6 +39,7 @@ class that handles its unique behaviors while integrating with Dynamo's
3839
get_fake_value,
3940
guard_if_dyn,
4041
iter_contains,
42+
Lit,
4143
namedtuple_fields,
4244
odict_values,
4345
raise_args_mismatch,
@@ -46,8 +48,8 @@ class that handles its unique behaviors while integrating with Dynamo's
4648
)
4749
from .base import ValueMutationNew, VariableTracker
4850
from .constant import ConstantVariable
51+
from .functions import UserFunctionVariable, UserMethodVariable
4952
from .iter import IteratorVariable
50-
from .user_defined import UserDefinedTupleVariable
5153

5254

5355
if TYPE_CHECKING:
@@ -1294,51 +1296,24 @@ def call_obj_hasattr(
12941296
return variables.ConstantVariable.create(hasattr(torch.Size, name))
12951297

12961298

1297-
class NamedTupleVariable(UserDefinedTupleVariable):
1299+
class NamedTupleVariable(TupleVariable):
12981300
_nonvar_fields = {
12991301
"tuple_cls",
13001302
"dynamic_attributes",
1301-
*UserDefinedTupleVariable._nonvar_fields,
1303+
*TupleVariable._nonvar_fields,
13021304
}
13031305

13041306
def __init__(
13051307
self,
13061308
items: list[VariableTracker],
1307-
tuple_cls: type[tuple],
1309+
tuple_cls: type,
13081310
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
13091311
**kwargs: Any,
13101312
) -> None:
1311-
tuple_vt = variables.TupleVariable(
1312-
items, mutation_type=kwargs.get("mutation_type", ValueMutationNew())
1313-
)
1314-
1315-
# Create a dummy instance for method resolution
1316-
# This allows _maybe_get_baseclass_method to work correctly
1317-
fields = namedtuple_fields(tuple_cls)
1318-
num_fields = len(fields)
1319-
if tuple_cls.__module__ == "torch.return_types":
1320-
# Structseq: single iterable argument
1321-
dummy_value = tuple_cls([None] * num_fields)
1322-
else:
1323-
# Namedtuple: positional arguments
1324-
dummy_value = tuple_cls(*([None] * num_fields)) # type: ignore[arg-type]
1325-
1326-
super().__init__(
1327-
value=dummy_value,
1328-
tuple_vt=tuple_vt,
1329-
init_args=None,
1330-
**kwargs,
1331-
)
1332-
1313+
super().__init__(items, **kwargs)
13331314
self.tuple_cls = tuple_cls
1334-
if len(self.tuple_cls.__mro__) < 3:
1335-
raise ValueError("NamedTuple should inherit from Tuple and Object.")
13361315
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
13371316

1338-
@property
1339-
def items(self) -> list[VariableTracker]:
1340-
return self._tuple_vt.items
1341-
13421317
def is_namedtuple(self) -> bool:
13431318
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
13441319
getattr(self.tuple_cls, "_make", None)
@@ -1350,7 +1325,17 @@ def is_structseq(self) -> bool:
13501325
def fields(self) -> tuple[str, ...]:
13511326
return namedtuple_fields(self.tuple_cls)
13521327

1353-
def as_python_constant(self):
1328+
def debug_repr(self) -> str:
1329+
if self.is_structseq():
1330+
# StructSequenceType(iterable)
1331+
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
1332+
# NamedTupleType(*iterable)
1333+
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
1334+
1335+
def python_type(self) -> type:
1336+
return self.tuple_cls
1337+
1338+
def as_python_constant(self) -> Any:
13541339
if self.is_structseq():
13551340
# StructSequenceType(iterable)
13561341
result = self.python_type()([x.as_python_constant() for x in self.items])
@@ -1372,39 +1357,57 @@ def as_python_constant(self):
13721357

13731358
return result
13741359

1375-
def as_proxy(self):
1360+
def as_proxy(self) -> Any:
1361+
assert self.python_type() is not SizeVariable
13761362
if self.is_structseq():
1377-
return self.python_type()([x.as_proxy() for x in self._tuple_vt.items])
1378-
return self.python_type()(*[x.as_proxy() for x in self._tuple_vt.items])
1363+
# StructSequenceType(iterable)
1364+
return self.python_type()(self._as_proxy())
1365+
# NamedTupleType(*iterable)
1366+
return self.python_type()(*self._as_proxy())
13791367

13801368
def reconstruct(self, codegen: "PyCodegen") -> None:
1369+
# Always reconstruct the NamedTuple normally first
1370+
# Constructors:
1371+
# StructSequenceType(iterable)
1372+
# NamedTupleType(*iterable)
1373+
# NamedTupleType._make(iterable)
13811374
if self.is_structseq():
13821375
create_fn = self.tuple_cls
13831376
else:
13841377
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
1385-
13861378
codegen.add_push_null(
13871379
lambda: codegen.append_output(
13881380
codegen.create_load_const_unchecked(create_fn)
13891381
)
13901382
)
1391-
codegen.foreach(self._tuple_vt.items)
1383+
codegen.foreach(self.items)
13921384
codegen.extend_output(
13931385
[
1394-
create_build_tuple(len(self._tuple_vt.items)),
1386+
create_build_tuple(len(self.items)),
13951387
]
13961388
+ create_call_function(1, False)
13971389
)
13981390

1399-
# Apply initial dynamic attributes after construction (if any)
1400-
# Runtime dynamic attributes are tracked via side effects system
14011391
for name, value in self.dynamic_attributes.items():
14021392
codegen.dup_top()
14031393
codegen(value)
14041394
codegen.extend_output(create_rot_n(2))
14051395
codegen.store_attr(name)
14061396

14071397
def _is_method_overridden(self, method_name: str) -> bool:
1398+
"""Checks if a method is overridden in the NamedTuple subclass.
1399+
1400+
Args:
1401+
method_name (str): The name of the method to check.
1402+
1403+
Returns:
1404+
bool: True if the method is overridden in the subclass, False otherwise.
1405+
1406+
Raises:
1407+
ValueError: If the NamedTuple class does not inherit from both Tuple and Object.
1408+
"""
1409+
if len(self.tuple_cls.__mro__) < 3:
1410+
raise ValueError("NamedTuple should inherit from Tuple and Object.")
14081411
if getattr(self.tuple_cls, method_name, None) == getattr(
14091412
self.tuple_cls.__mro__[-3], method_name, None
14101413
):
@@ -1418,53 +1421,129 @@ def call_method(
14181421
args: list[VariableTracker],
14191422
kwargs: dict[str, VariableTracker],
14201423
) -> VariableTracker:
1421-
if self._is_method_overridden(name):
1422-
# Fall back to UserDefinedTupleVariable
1423-
return super().call_method(tx, name, args, kwargs)
1424-
elif name == "__setattr__":
1424+
if name == "__setattr__":
14251425
if kwargs or len(args) != 2:
14261426
raise_args_mismatch(
14271427
tx,
14281428
name,
14291429
"2 args and 0 kwargs",
14301430
f"{len(args)} args and {len(kwargs)} kwargs",
14311431
)
1432-
attr_var, value = args
1433-
attr = attr_var.as_python_constant()
1434-
1432+
attr, value = args
1433+
attr = attr.as_python_constant()
14351434
if (
14361435
# structseq is immutable
14371436
self.is_structseq()
14381437
# namedtuple directly created by `collections.namedtuple` is immutable
14391438
or self.tuple_cls.__bases__ == (tuple,)
1439+
# fields are immutable
14401440
or attr in self.fields()
14411441
):
14421442
raise_observed_exception(AttributeError, tx)
1443-
1444-
result = self.method_setattr_standard(tx, attr_var, value)
1445-
# Also update self.dynamic_attributes
1443+
# Subclass of namedtuple type can have dynamic attributes
1444+
tx.output.side_effects.mutation(self)
1445+
if self.source:
1446+
tx.output.side_effects.store_attr(self, attr, value)
14461447
self.dynamic_attributes[attr] = value
1447-
return result
1448+
return ConstantVariable.create(None)
1449+
elif name == "_replace":
1450+
# NamedTuple._replace should create a new instance with replaced fields
1451+
if args:
1452+
raise_args_mismatch(tx, name, "0 args", f"{len(args)} args")
1453+
1454+
# Get the field names for validation
1455+
fields = self.fields()
1456+
1457+
# Start with current items (copy them)
1458+
new_items = list(self.items)
1459+
1460+
# Replace fields specified in kwargs
1461+
for field_name, new_value in kwargs.items():
1462+
if field_name not in fields:
1463+
raise_observed_exception(
1464+
ValueError,
1465+
tx,
1466+
args=[
1467+
ConstantVariable.create(
1468+
f"Got unexpected field name: '{field_name}'"
1469+
)
1470+
],
1471+
)
1472+
1473+
# Replace the item at the field's index
1474+
field_index = fields.index(field_name)
1475+
new_items[field_index] = new_value
1476+
1477+
return NamedTupleVariable(new_items, self.tuple_cls)
14481478

14491479
return super().call_method(tx, name, args, kwargs)
14501480

1451-
def python_type(self) -> type:
1452-
return self.tuple_cls
1481+
def getitem_const(
1482+
self, tx: "InstructionTranslator", arg: VariableTracker
1483+
) -> VariableTracker:
1484+
if isinstance(arg, SliceVariable):
1485+
# slicing a namedtuple produces a tuple
1486+
return TupleVariable(
1487+
self.items[arg.as_python_constant()],
1488+
source=None,
1489+
)
1490+
return super().getitem_const(tx, arg)
1491+
1492+
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
1493+
def check_and_create_method() -> Optional[VariableTracker]:
1494+
method = inspect.getattr_static(self.tuple_cls, name, None)
1495+
if isinstance(method, classmethod):
1496+
# We need the unbounded cls method to avoid the inline __self__
1497+
return UserMethodVariable(
1498+
method.__func__,
1499+
variables.UserDefinedClassVariable(self.tuple_cls),
1500+
)
1501+
elif isinstance(method, staticmethod):
1502+
# pyrefly: ignore[bad-argument-type]
1503+
return UserFunctionVariable(method.__func__)
1504+
elif inspect.isfunction(method):
1505+
return UserMethodVariable(method, self)
1506+
else:
1507+
return None
1508+
1509+
# Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten.
1510+
if (
1511+
name == "_replace"
1512+
and not self._is_method_overridden("_replace")
1513+
and not self._is_method_overridden("__getattr__")
1514+
):
1515+
# Return a BuiltinVariable for the _replace method
1516+
# Get the actual _replace method from the tuple class
1517+
actual_replace_method = getattr(self.tuple_cls, "_replace", None)
1518+
if actual_replace_method:
1519+
from ..source import AttrSource
1520+
1521+
source = AttrSource(self.source, name) if self.source else None
1522+
return variables.GetAttrVariable(self, name, source=source)
1523+
# Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples)
1524+
return super().var_getattr(tx, name)
14531525

1454-
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
14551526
if name == "_fields":
1456-
source = NamedTupleFieldsSource(self.source) if self.source else None
1457-
return VariableTracker.build(tx, self.fields(), source=source)
1527+
result_source = NamedTupleFieldsSource(self.source) if self.source else None
1528+
return VariableTracker.build(tx, self.fields(), source=result_source)
14581529

14591530
if name in self.dynamic_attributes:
14601531
return self.dynamic_attributes[name]
14611532

14621533
fields = self.fields()
1463-
if name in fields:
1464-
field_index = fields.index(name)
1465-
return self._tuple_vt.items[field_index]
1534+
if name not in fields:
1535+
method = check_and_create_method()
1536+
if not method:
1537+
return super().var_getattr(tx, name)
1538+
return method
1539+
return self.items[fields.index(name)]
14661540

1467-
return super().var_getattr(tx, name)
1541+
def call_obj_hasattr(
1542+
self, tx: "InstructionTranslator", name: str
1543+
) -> VariableTracker:
1544+
return variables.ConstantVariable.create(
1545+
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
1546+
)
14681547

14691548

14701549
class SliceVariable(VariableTracker):

0 commit comments

Comments
 (0)