@@ -475,7 +475,7 @@ def disable_tracing() -> Iterator[proxy_tensor.PythonKeyTracer]:
475
475
476
476
@staticmethod
477
477
def should_become_arg (value : object ) -> bool :
478
- if isinstance (value , (Tile , torch .SymInt )):
478
+ if isinstance (value , (Tile , int , float , bool , type ( None ), torch .SymInt )):
479
479
return False
480
480
if isinstance (value , torch .Tensor ):
481
481
if (
@@ -502,20 +502,54 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
502
502
end = self .visit (args [1 ])
503
503
return begin , end
504
504
505
+ def _handle_sequence_unrolling (
506
+ self ,
507
+ sequence_iter : ast .AST ,
508
+ target : ast .AST ,
509
+ element_processor : Callable [[], object | None ],
510
+ preserve_scope : bool = False ,
511
+ ) -> list [object ]:
512
+ """Common logic for unrolling sequences in both loops and comprehensions."""
513
+ # Get the sequence of values to iterate over
514
+ sequence_value = self .visit (sequence_iter )
515
+ assert isinstance (sequence_value , (tuple , list )), (
516
+ f"Expected tuple or list, got { type (sequence_value )} "
517
+ )
518
+
519
+ results = []
520
+ for element_value in sequence_value :
521
+ if preserve_scope :
522
+ # For loops: don't create new scope, allow state to persist
523
+ self ._assign (target , element_value )
524
+ result = element_processor ()
525
+ if result is not None :
526
+ results .append (result )
527
+ else :
528
+ # For comprehensions: create isolated scope for each iteration
529
+ old_scope = self .scope .copy ()
530
+ try :
531
+ self ._assign (target , element_value )
532
+ result = element_processor ()
533
+ if result is not None :
534
+ results .append (result )
535
+ finally :
536
+ self .scope = old_scope
537
+
538
+ return results
539
+
505
540
def _handle_tuple_unrolling (
506
541
self ,
507
542
node : ast .For ,
508
543
) -> None :
509
544
"""Handle unrolling of loops that iterate over tuples of tensors."""
510
- # Get the sequence of tensors to iterate over
511
- sequence_value = self .visit (node .iter )
512
- assert isinstance (sequence_value , (tuple , list )), (
513
- f"Expected tuple or list, got { type (sequence_value )} "
514
- )
515
- # Unroll the loop by executing the body for each tensor in the sequence
516
- for tensor_value in sequence_value :
517
- self ._assign (node .target , tensor_value )
545
+
546
+ def execute_body () -> None :
518
547
self ._body (node .body )
548
+ return None # No result to collect for loops
549
+
550
+ self ._handle_sequence_unrolling (
551
+ node .iter , node .target , execute_body , preserve_scope = True
552
+ )
519
553
520
554
def visit_For (self , node : ast .For ) -> None :
521
555
assert isinstance (node , ExtendedAST )
@@ -528,6 +562,15 @@ def visit_For(self, node: ast.For) -> None:
528
562
self ._handle_tuple_unrolling (node )
529
563
return
530
564
565
+ # Special handling for variables that might contain sequences from list comprehensions
566
+ if isinstance (node .iter , ast .Name ) and node .iter .id in self .scope :
567
+ scope_value = self .scope [node .iter .id ]
568
+ if isinstance (scope_value , (tuple , list )):
569
+ # This is a sequence in the scope, we should try to unroll it
570
+ # even if the type info doesn't indicate it's a SequenceType
571
+ self ._handle_tuple_unrolling (node )
572
+ return
573
+
531
574
if not isinstance (iter_type , IterType ):
532
575
raise exc .InvalidDeviceForLoop (iter_type )
533
576
inner_type : TypeInfo = iter_type .inner
@@ -724,6 +767,50 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[object, ...]:
724
767
def visit_List (self , node : ast .List ) -> list [object ]:
725
768
return [self .visit (x ) for x in node .elts ]
726
769
770
+ def visit_ListComp (self , node : ast .ListComp ) -> tuple [object , ...]:
771
+ """Handle list comprehension unrolling similar to tuple unrolling."""
772
+ assert isinstance (node , ExtendedAST )
773
+
774
+ # Only handle simple cases with single generator and no if conditions
775
+ if len (node .generators ) != 1 or node .generators [0 ].ifs :
776
+ raise exc .StatementNotSupported (
777
+ "Complex list comprehensions are not supported"
778
+ )
779
+
780
+ generator = node .generators [0 ]
781
+ assert isinstance (generator .iter , ExtendedAST )
782
+ iter_type = generator .iter ._type_info
783
+
784
+ # Check if we're iterating over a sequence (similar to tuple unrolling)
785
+ if isinstance (iter_type , SequenceType ):
786
+ return self ._handle_listcomp_unrolling (node )
787
+
788
+ # For non-sequence iterables, we could extend this later
789
+ raise exc .StatementNotSupported (
790
+ "List comprehensions over non-sequence types are not supported"
791
+ )
792
+
793
+ def _handle_listcomp_unrolling (self , node : ast .ListComp ) -> tuple [object , ...]:
794
+ """Handle unrolling of list comprehensions over sequences."""
795
+ generator = node .generators [0 ]
796
+
797
+ def evaluate_expression () -> object :
798
+ # Evaluate the comprehension expression
799
+ result = self .visit (node .elt )
800
+ # If the result is a SymInt that can be evaluated to a concrete value, do so
801
+ if isinstance (result , torch .SymInt ):
802
+ try :
803
+ return int (result )
804
+ except (ValueError , TypeError ):
805
+ return result
806
+ return result
807
+
808
+ results = self ._handle_sequence_unrolling (
809
+ generator .iter , generator .target , evaluate_expression , preserve_scope = False
810
+ )
811
+ # Return as tuple to match the expected type for tuple unrolling
812
+ return tuple (results )
813
+
727
814
def visit_Slice (self , node : ast .Slice ) -> slice :
728
815
if node .lower is None :
729
816
lower = None
0 commit comments