@@ -15,7 +15,6 @@ class that handles its unique behaviors while integrating with Dynamo's
1515"""
1616
1717import collections
18- import inspect
1918import operator
2019import sys
2120from collections .abc import Sequence
@@ -39,7 +38,6 @@ class that handles its unique behaviors while integrating with Dynamo's
3938 get_fake_value ,
4039 guard_if_dyn ,
4140 iter_contains ,
42- Lit ,
4341 namedtuple_fields ,
4442 odict_values ,
4543 raise_args_mismatch ,
@@ -48,8 +46,8 @@ class that handles its unique behaviors while integrating with Dynamo's
4846)
4947from .base import ValueMutationNew , VariableTracker
5048from .constant import ConstantVariable
51- from .functions import UserFunctionVariable , UserMethodVariable
5249from .iter import IteratorVariable
50+ from .user_defined import UserDefinedTupleVariable
5351
5452
5553if TYPE_CHECKING :
@@ -1296,24 +1294,51 @@ def call_obj_hasattr(
12961294 return variables .ConstantVariable .create (hasattr (torch .Size , name ))
12971295
12981296
1299- class NamedTupleVariable (TupleVariable ):
1297+ class NamedTupleVariable (UserDefinedTupleVariable ):
13001298 _nonvar_fields = {
13011299 "tuple_cls" ,
13021300 "dynamic_attributes" ,
1303- * TupleVariable ._nonvar_fields ,
1301+ * UserDefinedTupleVariable ._nonvar_fields ,
13041302 }
13051303
13061304 def __init__ (
13071305 self ,
13081306 items : list [VariableTracker ],
1309- tuple_cls : type ,
1307+ tuple_cls : type [ tuple ] ,
13101308 dynamic_attributes : Optional [dict [str , VariableTracker ]] = None ,
13111309 ** kwargs : Any ,
13121310 ) -> None :
1313- super ().__init__ (items , ** kwargs )
1311+ tuple_vt = variables .TupleVariable (
1312+ items , mutation_type = kwargs .get ("mutation_type" , ValueMutationNew ())
1313+ )
1314+
1315+ # Create a dummy instance for method resolution
1316+ # This allows _maybe_get_baseclass_method to work correctly
1317+ fields = namedtuple_fields (tuple_cls )
1318+ num_fields = len (fields )
1319+ if tuple_cls .__module__ == "torch.return_types" :
1320+ # Structseq: single iterable argument
1321+ dummy_value = tuple_cls ([None ] * num_fields )
1322+ else :
1323+ # Namedtuple: positional arguments
1324+ dummy_value = tuple_cls (* ([None ] * num_fields )) # type: ignore[arg-type]
1325+
1326+ super ().__init__ (
1327+ value = dummy_value ,
1328+ tuple_vt = tuple_vt ,
1329+ init_args = None ,
1330+ ** kwargs ,
1331+ )
1332+
13141333 self .tuple_cls = tuple_cls
1334+ if len (self .tuple_cls .__mro__ ) < 3 :
1335+ raise ValueError ("NamedTuple should inherit from Tuple and Object." )
13151336 self .dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
13161337
1338+ @property
1339+ def items (self ) -> list [VariableTracker ]:
1340+ return self ._tuple_vt .items
1341+
13171342 def is_namedtuple (self ) -> bool :
13181343 return isinstance (getattr (self .tuple_cls , "_fields" , None ), tuple ) and callable (
13191344 getattr (self .tuple_cls , "_make" , None )
@@ -1325,17 +1350,7 @@ def is_structseq(self) -> bool:
13251350 def fields (self ) -> tuple [str , ...]:
13261351 return namedtuple_fields (self .tuple_cls )
13271352
1328- def debug_repr (self ) -> str :
1329- if self .is_structseq ():
1330- # StructSequenceType(iterable)
1331- return repr (self .tuple_cls ([Lit (x .debug_repr ()) for x in self .items ]))
1332- # NamedTupleType(*iterable)
1333- return repr (self .tuple_cls (* (Lit (x .debug_repr ()) for x in self .items )))
1334-
1335- def python_type (self ) -> type :
1336- return self .tuple_cls
1337-
1338- def as_python_constant (self ) -> Any :
1353+ def as_python_constant (self ):
13391354 if self .is_structseq ():
13401355 # StructSequenceType(iterable)
13411356 result = self .python_type ()([x .as_python_constant () for x in self .items ])
@@ -1357,57 +1372,39 @@ def as_python_constant(self) -> Any:
13571372
13581373 return result
13591374
1360- def as_proxy (self ) -> Any :
1361- assert self .python_type () is not SizeVariable
1375+ def as_proxy (self ):
13621376 if self .is_structseq ():
1363- # StructSequenceType(iterable)
1364- return self .python_type ()(self ._as_proxy ())
1365- # NamedTupleType(*iterable)
1366- return self .python_type ()(* self ._as_proxy ())
1377+ return self .python_type ()([x .as_proxy () for x in self ._tuple_vt .items ])
1378+ return self .python_type ()(* [x .as_proxy () for x in self ._tuple_vt .items ])
13671379
13681380 def reconstruct (self , codegen : "PyCodegen" ) -> None :
1369- # Always reconstruct the NamedTuple normally first
1370- # Constructors:
1371- # StructSequenceType(iterable)
1372- # NamedTupleType(*iterable)
1373- # NamedTupleType._make(iterable)
13741381 if self .is_structseq ():
13751382 create_fn = self .tuple_cls
13761383 else :
13771384 create_fn = self .tuple_cls ._make # type: ignore[attr-defined]
1385+
13781386 codegen .add_push_null (
13791387 lambda : codegen .append_output (
13801388 codegen .create_load_const_unchecked (create_fn )
13811389 )
13821390 )
1383- codegen .foreach (self .items )
1391+ codegen .foreach (self ._tuple_vt . items )
13841392 codegen .extend_output (
13851393 [
1386- create_build_tuple (len (self .items )),
1394+ create_build_tuple (len (self ._tuple_vt . items )),
13871395 ]
13881396 + create_call_function (1 , False )
13891397 )
13901398
1399+ # Apply initial dynamic attributes after construction (if any)
1400+ # Runtime dynamic attributes are tracked via side effects system
13911401 for name , value in self .dynamic_attributes .items ():
13921402 codegen .dup_top ()
13931403 codegen (value )
13941404 codegen .extend_output (create_rot_n (2 ))
13951405 codegen .store_attr (name )
13961406
13971407 def _is_method_overridden (self , method_name : str ) -> bool :
1398- """Checks if a method is overridden in the NamedTuple subclass.
1399-
1400- Args:
1401- method_name (str): The name of the method to check.
1402-
1403- Returns:
1404- bool: True if the method is overridden in the subclass, False otherwise.
1405-
1406- Raises:
1407- ValueError: If the NamedTuple class does not inherit from both Tuple and Object.
1408- """
1409- if len (self .tuple_cls .__mro__ ) < 3 :
1410- raise ValueError ("NamedTuple should inherit from Tuple and Object." )
14111408 if getattr (self .tuple_cls , method_name , None ) == getattr (
14121409 self .tuple_cls .__mro__ [- 3 ], method_name , None
14131410 ):
@@ -1421,129 +1418,53 @@ def call_method(
14211418 args : list [VariableTracker ],
14221419 kwargs : dict [str , VariableTracker ],
14231420 ) -> VariableTracker :
1424- if name == "__setattr__" :
1421+ if self ._is_method_overridden (name ):
1422+ # Fall back to UserDefinedTupleVariable
1423+ return super ().call_method (tx , name , args , kwargs )
1424+ elif name == "__setattr__" :
14251425 if kwargs or len (args ) != 2 :
14261426 raise_args_mismatch (
14271427 tx ,
14281428 name ,
14291429 "2 args and 0 kwargs" ,
14301430 f"{ len (args )} args and { len (kwargs )} kwargs" ,
14311431 )
1432- attr , value = args
1433- attr = attr .as_python_constant ()
1432+ attr_var , value = args
1433+ attr = attr_var .as_python_constant ()
1434+
14341435 if (
14351436 # structseq is immutable
14361437 self .is_structseq ()
14371438 # namedtuple directly created by `collections.namedtuple` is immutable
14381439 or self .tuple_cls .__bases__ == (tuple ,)
1439- # fields are immutable
14401440 or attr in self .fields ()
14411441 ):
14421442 raise_observed_exception (AttributeError , tx )
1443- # Subclass of namedtuple type can have dynamic attributes
1444- tx .output .side_effects .mutation (self )
1445- if self .source :
1446- tx .output .side_effects .store_attr (self , attr , value )
1447- self .dynamic_attributes [attr ] = value
1448- return ConstantVariable .create (None )
1449- elif name == "_replace" :
1450- # NamedTuple._replace should create a new instance with replaced fields
1451- if args :
1452- raise_args_mismatch (tx , name , "0 args" , f"{ len (args )} args" )
1453-
1454- # Get the field names for validation
1455- fields = self .fields ()
1456-
1457- # Start with current items (copy them)
1458- new_items = list (self .items )
1459-
1460- # Replace fields specified in kwargs
1461- for field_name , new_value in kwargs .items ():
1462- if field_name not in fields :
1463- raise_observed_exception (
1464- ValueError ,
1465- tx ,
1466- args = [
1467- ConstantVariable .create (
1468- f"Got unexpected field name: '{ field_name } '"
1469- )
1470- ],
1471- )
1472-
1473- # Replace the item at the field's index
1474- field_index = fields .index (field_name )
1475- new_items [field_index ] = new_value
14761443
1477- return NamedTupleVariable (new_items , self .tuple_cls )
1444+ result = self .method_setattr_standard (tx , attr_var , value )
1445+ # Also update self.dynamic_attributes
1446+ self .dynamic_attributes [attr ] = value
1447+ return result
14781448
14791449 return super ().call_method (tx , name , args , kwargs )
14801450
1481- def getitem_const (
1482- self , tx : "InstructionTranslator" , arg : VariableTracker
1483- ) -> VariableTracker :
1484- if isinstance (arg , SliceVariable ):
1485- # slicing a namedtuple produces a tuple
1486- return TupleVariable (
1487- self .items [arg .as_python_constant ()],
1488- source = None ,
1489- )
1490- return super ().getitem_const (tx , arg )
1491-
1492- def var_getattr (self , tx : "InstructionTranslator" , name : str ) -> VariableTracker :
1493- def check_and_create_method () -> Optional [VariableTracker ]:
1494- method = inspect .getattr_static (self .tuple_cls , name , None )
1495- if isinstance (method , classmethod ):
1496- # We need the unbounded cls method to avoid the inline __self__
1497- return UserMethodVariable (
1498- method .__func__ ,
1499- variables .UserDefinedClassVariable (self .tuple_cls ),
1500- )
1501- elif isinstance (method , staticmethod ):
1502- # pyrefly: ignore[bad-argument-type]
1503- return UserFunctionVariable (method .__func__ )
1504- elif inspect .isfunction (method ):
1505- return UserMethodVariable (method , self )
1506- else :
1507- return None
1508-
1509- # Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten.
1510- if (
1511- name == "_replace"
1512- and not self ._is_method_overridden ("_replace" )
1513- and not self ._is_method_overridden ("__getattr__" )
1514- ):
1515- # Return a BuiltinVariable for the _replace method
1516- # Get the actual _replace method from the tuple class
1517- actual_replace_method = getattr (self .tuple_cls , "_replace" , None )
1518- if actual_replace_method :
1519- from ..source import AttrSource
1520-
1521- source = AttrSource (self .source , name ) if self .source else None
1522- return variables .GetAttrVariable (self , name , source = source )
1523- # Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples)
1524- return super ().var_getattr (tx , name )
1451+ def python_type (self ) -> type :
1452+ return self .tuple_cls
15251453
1454+ def var_getattr (self , tx : "InstructionTranslator" , name : str ) -> "VariableTracker" :
15261455 if name == "_fields" :
1527- result_source = NamedTupleFieldsSource (self .source ) if self .source else None
1528- return VariableTracker .build (tx , self .fields (), source = result_source )
1456+ source = NamedTupleFieldsSource (self .source ) if self .source else None
1457+ return VariableTracker .build (tx , self .fields (), source = source )
15291458
15301459 if name in self .dynamic_attributes :
15311460 return self .dynamic_attributes [name ]
15321461
15331462 fields = self .fields ()
1534- if name not in fields :
1535- method = check_and_create_method ()
1536- if not method :
1537- return super ().var_getattr (tx , name )
1538- return method
1539- return self .items [fields .index (name )]
1463+ if name in fields :
1464+ field_index = fields .index (name )
1465+ return self ._tuple_vt .items [field_index ]
15401466
1541- def call_obj_hasattr (
1542- self , tx : "InstructionTranslator" , name : str
1543- ) -> VariableTracker :
1544- return variables .ConstantVariable .create (
1545- name in self .dynamic_attributes or hasattr (self .tuple_cls , name )
1546- )
1467+ return super ().var_getattr (tx , name )
15471468
15481469
15491470class SliceVariable (VariableTracker ):
0 commit comments