@@ -15,6 +15,7 @@ class that handles its unique behaviors while integrating with Dynamo's
1515"""
1616
1717import collections
18+ import inspect
1819import operator
1920import sys
2021from collections .abc import Sequence
@@ -38,6 +39,7 @@ class that handles its unique behaviors while integrating with Dynamo's
3839 get_fake_value ,
3940 guard_if_dyn ,
4041 iter_contains ,
42+ Lit ,
4143 namedtuple_fields ,
4244 odict_values ,
4345 raise_args_mismatch ,
@@ -46,8 +48,8 @@ class that handles its unique behaviors while integrating with Dynamo's
4648)
4749from .base import ValueMutationNew , VariableTracker
4850from .constant import ConstantVariable
51+ from .functions import UserFunctionVariable , UserMethodVariable
4952from .iter import IteratorVariable
50- from .user_defined import UserDefinedTupleVariable
5153
5254
5355if TYPE_CHECKING :
@@ -1294,51 +1296,24 @@ def call_obj_hasattr(
12941296 return variables .ConstantVariable .create (hasattr (torch .Size , name ))
12951297
12961298
1297- class NamedTupleVariable (UserDefinedTupleVariable ):
1299+ class NamedTupleVariable (TupleVariable ):
12981300 _nonvar_fields = {
12991301 "tuple_cls" ,
13001302 "dynamic_attributes" ,
1301- * UserDefinedTupleVariable ._nonvar_fields ,
1303+ * TupleVariable ._nonvar_fields ,
13021304 }
13031305
13041306 def __init__ (
13051307 self ,
13061308 items : list [VariableTracker ],
1307- tuple_cls : type [ tuple ] ,
1309+ tuple_cls : type ,
13081310 dynamic_attributes : Optional [dict [str , VariableTracker ]] = None ,
13091311 ** kwargs : Any ,
13101312 ) -> None :
1311- tuple_vt = variables .TupleVariable (
1312- items , mutation_type = kwargs .get ("mutation_type" , ValueMutationNew ())
1313- )
1314-
1315- # Create a dummy instance for method resolution
1316- # This allows _maybe_get_baseclass_method to work correctly
1317- fields = namedtuple_fields (tuple_cls )
1318- num_fields = len (fields )
1319- if tuple_cls .__module__ == "torch.return_types" :
1320- # Structseq: single iterable argument
1321- dummy_value = tuple_cls ([None ] * num_fields )
1322- else :
1323- # Namedtuple: positional arguments
1324- dummy_value = tuple_cls (* ([None ] * num_fields )) # type: ignore[arg-type]
1325-
1326- super ().__init__ (
1327- value = dummy_value ,
1328- tuple_vt = tuple_vt ,
1329- init_args = None ,
1330- ** kwargs ,
1331- )
1332-
1313+ super ().__init__ (items , ** kwargs )
13331314 self .tuple_cls = tuple_cls
1334- if len (self .tuple_cls .__mro__ ) < 3 :
1335- raise ValueError ("NamedTuple should inherit from Tuple and Object." )
13361315 self .dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
13371316
1338- @property
1339- def items (self ) -> list [VariableTracker ]:
1340- return self ._tuple_vt .items
1341-
13421317 def is_namedtuple (self ) -> bool :
13431318 return isinstance (getattr (self .tuple_cls , "_fields" , None ), tuple ) and callable (
13441319 getattr (self .tuple_cls , "_make" , None )
@@ -1350,7 +1325,17 @@ def is_structseq(self) -> bool:
13501325 def fields (self ) -> tuple [str , ...]:
13511326 return namedtuple_fields (self .tuple_cls )
13521327
1353- def as_python_constant (self ):
1328+ def debug_repr (self ) -> str :
1329+ if self .is_structseq ():
1330+ # StructSequenceType(iterable)
1331+ return repr (self .tuple_cls ([Lit (x .debug_repr ()) for x in self .items ]))
1332+ # NamedTupleType(*iterable)
1333+ return repr (self .tuple_cls (* (Lit (x .debug_repr ()) for x in self .items )))
1334+
1335+ def python_type (self ) -> type :
1336+ return self .tuple_cls
1337+
1338+ def as_python_constant (self ) -> Any :
13541339 if self .is_structseq ():
13551340 # StructSequenceType(iterable)
13561341 result = self .python_type ()([x .as_python_constant () for x in self .items ])
@@ -1372,39 +1357,57 @@ def as_python_constant(self):
13721357
13731358 return result
13741359
1375- def as_proxy (self ):
1360+ def as_proxy (self ) -> Any :
1361+ assert self .python_type () is not SizeVariable
13761362 if self .is_structseq ():
1377- return self .python_type ()([x .as_proxy () for x in self ._tuple_vt .items ])
1378- return self .python_type ()(* [x .as_proxy () for x in self ._tuple_vt .items ])
1363+ # StructSequenceType(iterable)
1364+ return self .python_type ()(self ._as_proxy ())
1365+ # NamedTupleType(*iterable)
1366+ return self .python_type ()(* self ._as_proxy ())
13791367
13801368 def reconstruct (self , codegen : "PyCodegen" ) -> None :
1369+ # Always reconstruct the NamedTuple normally first
1370+ # Constructors:
1371+ # StructSequenceType(iterable)
1372+ # NamedTupleType(*iterable)
1373+ # NamedTupleType._make(iterable)
13811374 if self .is_structseq ():
13821375 create_fn = self .tuple_cls
13831376 else :
13841377 create_fn = self .tuple_cls ._make # type: ignore[attr-defined]
1385-
13861378 codegen .add_push_null (
13871379 lambda : codegen .append_output (
13881380 codegen .create_load_const_unchecked (create_fn )
13891381 )
13901382 )
1391- codegen .foreach (self ._tuple_vt . items )
1383+ codegen .foreach (self .items )
13921384 codegen .extend_output (
13931385 [
1394- create_build_tuple (len (self ._tuple_vt . items )),
1386+ create_build_tuple (len (self .items )),
13951387 ]
13961388 + create_call_function (1 , False )
13971389 )
13981390
1399- # Apply initial dynamic attributes after construction (if any)
1400- # Runtime dynamic attributes are tracked via side effects system
14011391 for name , value in self .dynamic_attributes .items ():
14021392 codegen .dup_top ()
14031393 codegen (value )
14041394 codegen .extend_output (create_rot_n (2 ))
14051395 codegen .store_attr (name )
14061396
14071397 def _is_method_overridden (self , method_name : str ) -> bool :
1398+ """Checks if a method is overridden in the NamedTuple subclass.
1399+
1400+ Args:
1401+ method_name (str): The name of the method to check.
1402+
1403+ Returns:
1404+ bool: True if the method is overridden in the subclass, False otherwise.
1405+
1406+ Raises:
1407+ ValueError: If the NamedTuple class does not inherit from both Tuple and Object.
1408+ """
1409+ if len (self .tuple_cls .__mro__ ) < 3 :
1410+ raise ValueError ("NamedTuple should inherit from Tuple and Object." )
14081411 if getattr (self .tuple_cls , method_name , None ) == getattr (
14091412 self .tuple_cls .__mro__ [- 3 ], method_name , None
14101413 ):
@@ -1418,53 +1421,129 @@ def call_method(
14181421 args : list [VariableTracker ],
14191422 kwargs : dict [str , VariableTracker ],
14201423 ) -> VariableTracker :
1421- if self ._is_method_overridden (name ):
1422- # Fall back to UserDefinedTupleVariable
1423- return super ().call_method (tx , name , args , kwargs )
1424- elif name == "__setattr__" :
1424+ if name == "__setattr__" :
14251425 if kwargs or len (args ) != 2 :
14261426 raise_args_mismatch (
14271427 tx ,
14281428 name ,
14291429 "2 args and 0 kwargs" ,
14301430 f"{ len (args )} args and { len (kwargs )} kwargs" ,
14311431 )
1432- attr_var , value = args
1433- attr = attr_var .as_python_constant ()
1434-
1432+ attr , value = args
1433+ attr = attr .as_python_constant ()
14351434 if (
14361435 # structseq is immutable
14371436 self .is_structseq ()
14381437 # namedtuple directly created by `collections.namedtuple` is immutable
14391438 or self .tuple_cls .__bases__ == (tuple ,)
1439+ # fields are immutable
14401440 or attr in self .fields ()
14411441 ):
14421442 raise_observed_exception (AttributeError , tx )
1443-
1444- result = self .method_setattr_standard (tx , attr_var , value )
1445- # Also update self.dynamic_attributes
1443+ # Subclass of namedtuple type can have dynamic attributes
1444+ tx .output .side_effects .mutation (self )
1445+ if self .source :
1446+ tx .output .side_effects .store_attr (self , attr , value )
14461447 self .dynamic_attributes [attr ] = value
1447- return result
1448+ return ConstantVariable .create (None )
1449+ elif name == "_replace" :
1450+ # NamedTuple._replace should create a new instance with replaced fields
1451+ if args :
1452+ raise_args_mismatch (tx , name , "0 args" , f"{ len (args )} args" )
1453+
1454+ # Get the field names for validation
1455+ fields = self .fields ()
1456+
1457+ # Start with current items (copy them)
1458+ new_items = list (self .items )
1459+
1460+ # Replace fields specified in kwargs
1461+ for field_name , new_value in kwargs .items ():
1462+ if field_name not in fields :
1463+ raise_observed_exception (
1464+ ValueError ,
1465+ tx ,
1466+ args = [
1467+ ConstantVariable .create (
1468+ f"Got unexpected field name: '{ field_name } '"
1469+ )
1470+ ],
1471+ )
1472+
1473+ # Replace the item at the field's index
1474+ field_index = fields .index (field_name )
1475+ new_items [field_index ] = new_value
1476+
1477+ return NamedTupleVariable (new_items , self .tuple_cls )
14481478
14491479 return super ().call_method (tx , name , args , kwargs )
14501480
1451- def python_type (self ) -> type :
1452- return self .tuple_cls
1481+ def getitem_const (
1482+ self , tx : "InstructionTranslator" , arg : VariableTracker
1483+ ) -> VariableTracker :
1484+ if isinstance (arg , SliceVariable ):
1485+ # slicing a namedtuple produces a tuple
1486+ return TupleVariable (
1487+ self .items [arg .as_python_constant ()],
1488+ source = None ,
1489+ )
1490+ return super ().getitem_const (tx , arg )
1491+
1492+ def var_getattr (self , tx : "InstructionTranslator" , name : str ) -> VariableTracker :
1493+ def check_and_create_method () -> Optional [VariableTracker ]:
1494+ method = inspect .getattr_static (self .tuple_cls , name , None )
1495+ if isinstance (method , classmethod ):
1496+ # We need the unbounded cls method to avoid the inline __self__
1497+ return UserMethodVariable (
1498+ method .__func__ ,
1499+ variables .UserDefinedClassVariable (self .tuple_cls ),
1500+ )
1501+ elif isinstance (method , staticmethod ):
1502+ # pyrefly: ignore[bad-argument-type]
1503+ return UserFunctionVariable (method .__func__ )
1504+ elif inspect .isfunction (method ):
1505+ return UserMethodVariable (method , self )
1506+ else :
1507+ return None
1508+
1509+ # Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten.
1510+ if (
1511+ name == "_replace"
1512+ and not self ._is_method_overridden ("_replace" )
1513+ and not self ._is_method_overridden ("__getattr__" )
1514+ ):
1515+ # Return a BuiltinVariable for the _replace method
1516+ # Get the actual _replace method from the tuple class
1517+ actual_replace_method = getattr (self .tuple_cls , "_replace" , None )
1518+ if actual_replace_method :
1519+ from ..source import AttrSource
1520+
1521+ source = AttrSource (self .source , name ) if self .source else None
1522+ return variables .GetAttrVariable (self , name , source = source )
1523+ # Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples)
1524+ return super ().var_getattr (tx , name )
14531525
1454- def var_getattr (self , tx : "InstructionTranslator" , name : str ) -> "VariableTracker" :
14551526 if name == "_fields" :
1456- source = NamedTupleFieldsSource (self .source ) if self .source else None
1457- return VariableTracker .build (tx , self .fields (), source = source )
1527+ result_source = NamedTupleFieldsSource (self .source ) if self .source else None
1528+ return VariableTracker .build (tx , self .fields (), source = result_source )
14581529
14591530 if name in self .dynamic_attributes :
14601531 return self .dynamic_attributes [name ]
14611532
14621533 fields = self .fields ()
1463- if name in fields :
1464- field_index = fields .index (name )
1465- return self ._tuple_vt .items [field_index ]
1534+ if name not in fields :
1535+ method = check_and_create_method ()
1536+ if not method :
1537+ return super ().var_getattr (tx , name )
1538+ return method
1539+ return self .items [fields .index (name )]
14661540
1467- return super ().var_getattr (tx , name )
1541+ def call_obj_hasattr (
1542+ self , tx : "InstructionTranslator" , name : str
1543+ ) -> VariableTracker :
1544+ return variables .ConstantVariable .create (
1545+ name in self .dynamic_attributes or hasattr (self .tuple_cls , name )
1546+ )
14681547
14691548
14701549class SliceVariable (VariableTracker ):
0 commit comments