Skip to content

Commit 457df0f

Browse files
authored
Add support for listcomp (#412)
1 parent 7d54ca4 commit 457df0f

File tree

6 files changed

+678
-15
lines changed

6 files changed

+678
-15
lines changed

helion/_compiler/device_ir.py

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def disable_tracing() -> Iterator[proxy_tensor.PythonKeyTracer]:
475475

476476
@staticmethod
477477
def should_become_arg(value: object) -> bool:
478-
if isinstance(value, (Tile, torch.SymInt)):
478+
if isinstance(value, (Tile, int, float, bool, type(None), torch.SymInt)):
479479
return False
480480
if isinstance(value, torch.Tensor):
481481
if (
@@ -502,20 +502,54 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
502502
end = self.visit(args[1])
503503
return begin, end
504504

505+
def _handle_sequence_unrolling(
    self,
    sequence_iter: ast.AST,
    target: ast.AST,
    element_processor: Callable[[], object | None],
    preserve_scope: bool = False,
) -> list[object]:
    """Unroll iteration over a tuple/list for both loops and comprehensions.

    Args:
        sequence_iter: AST node that is visited to obtain the sequence.
        target: AST node each element is assigned to via ``self._assign``.
        element_processor: Zero-argument callback run once per element,
            after the element has been assigned to ``target``.
        preserve_scope: When True (for-loop mode) assignments made while
            processing an element stay in ``self.scope`` and ``None``
            results are not collected. When False (comprehension mode)
            ``self.scope`` is replaced by a pre-iteration copy after every
            element and every result is collected.

    Returns:
        The collected callback results, in iteration order.

    Raises:
        AssertionError: If visiting ``sequence_iter`` does not yield a
            tuple or list.
    """
    sequence_value = self.visit(sequence_iter)
    assert isinstance(sequence_value, (tuple, list)), (
        f"Expected tuple or list, got {type(sequence_value)}"
    )

    results: list[object] = []
    for element_value in sequence_value:
        if preserve_scope:
            # For loops: no new scope, state persists across iterations.
            self._assign(target, element_value)
            result = element_processor()
            if result is not None:
                results.append(result)
            continue

        # Comprehensions: snapshot the scope so the loop variable (and
        # anything else bound while evaluating the element) does not leak.
        saved_scope = self.scope.copy()
        try:
            self._assign(target, element_value)
            # Always collect the value, including None: a comprehension
            # yields exactly one element per iteration, so dropping None
            # would make the result shorter than the input sequence.
            results.append(element_processor())
        finally:
            self.scope = saved_scope

    return results
539+
505540
def _handle_tuple_unrolling(
506541
self,
507542
node: ast.For,
508543
) -> None:
509544
"""Handle unrolling of loops that iterate over tuples of tensors."""
510-
# Get the sequence of tensors to iterate over
511-
sequence_value = self.visit(node.iter)
512-
assert isinstance(sequence_value, (tuple, list)), (
513-
f"Expected tuple or list, got {type(sequence_value)}"
514-
)
515-
# Unroll the loop by executing the body for each tensor in the sequence
516-
for tensor_value in sequence_value:
517-
self._assign(node.target, tensor_value)
545+
546+
def execute_body() -> None:
518547
self._body(node.body)
548+
return None # No result to collect for loops
549+
550+
self._handle_sequence_unrolling(
551+
node.iter, node.target, execute_body, preserve_scope=True
552+
)
519553

520554
def visit_For(self, node: ast.For) -> None:
521555
assert isinstance(node, ExtendedAST)
@@ -528,6 +562,15 @@ def visit_For(self, node: ast.For) -> None:
528562
self._handle_tuple_unrolling(node)
529563
return
530564

565+
# Special handling for variables that might contain sequences from list comprehensions
566+
if isinstance(node.iter, ast.Name) and node.iter.id in self.scope:
567+
scope_value = self.scope[node.iter.id]
568+
if isinstance(scope_value, (tuple, list)):
569+
# This is a sequence in the scope, we should try to unroll it
570+
# even if the type info doesn't indicate it's a SequenceType
571+
self._handle_tuple_unrolling(node)
572+
return
573+
531574
if not isinstance(iter_type, IterType):
532575
raise exc.InvalidDeviceForLoop(iter_type)
533576
inner_type: TypeInfo = iter_type.inner
@@ -724,6 +767,50 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[object, ...]:
724767
def visit_List(self, node: ast.List) -> list[object]:
    """Visit every element of a list literal in order and return the values."""
    values: list[object] = []
    for element in node.elts:
        values.append(self.visit(element))
    return values
726769

770+
def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]:
    """Unroll a simple list comprehension over a sequence-typed iterable.

    Only comprehensions with exactly one generator and no ``if`` filters
    are accepted, and the iterable's recorded type info must be a
    ``SequenceType``. Anything else raises ``exc.StatementNotSupported``.

    Returns:
        The unrolled element values as a tuple.
    """
    assert isinstance(node, ExtendedAST)

    generators = node.generators
    if len(generators) != 1 or generators[0].ifs:
        raise exc.StatementNotSupported(
            "Complex list comprehensions are not supported"
        )

    iterable = generators[0].iter
    assert isinstance(iterable, ExtendedAST)

    # Guard clause: only sequence-typed iterables can be unrolled today.
    if not isinstance(iterable._type_info, SequenceType):
        raise exc.StatementNotSupported(
            "List comprehensions over non-sequence types are not supported"
        )
    return self._handle_listcomp_unrolling(node)
792+
793+
def _handle_listcomp_unrolling(self, node: ast.ListComp) -> tuple[object, ...]:
    """Unroll a single-generator list comprehension over a sequence.

    The element expression is evaluated once per item through
    ``self._handle_sequence_unrolling`` with ``preserve_scope=False``.

    Returns:
        The per-item values as a tuple (a tuple rather than a list, to
        match what tuple unrolling expects).
    """
    comp = node.generators[0]

    def evaluate_expression() -> object:
        value = self.visit(node.elt)
        if not isinstance(value, torch.SymInt):
            return value
        # Prefer a concrete int when the SymInt converts cleanly; otherwise
        # keep the symbolic value.
        # NOTE(review): int() on a SymInt may have specialization side
        # effects in torch -- confirm this is intended.
        try:
            return int(value)
        except (ValueError, TypeError):
            return value

    unrolled = self._handle_sequence_unrolling(
        comp.iter, comp.target, evaluate_expression, preserve_scope=False
    )
    return tuple(unrolled)
813+
727814
def visit_Slice(self, node: ast.Slice) -> slice:
728815
if node.lower is None:
729816
lower = None

helion/_compiler/inductor_lowering.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,30 @@ def _default(
969969
result_str = _unpack_opsvalue(
970970
getattr(self.parent_handler, name)(*args, **kwargs)
971971
)
972+
973+
return self.cg.lift(expr_from_string(result_str)).id
974+
975+
def to_dtype(
    self,
    x: object,
    dtype: torch.dtype,
    src_dtype: torch.dtype | None = None,
    use_compute_types: bool = True,
) -> str:
    """Emit a dtype conversion, using ``tl.cast`` for ``_item_`` scalars.

    When ``str(x)`` contains ``"_item_"`` the value is treated as a plain
    scalar (per the original author, typically one from GetItemOrigin) and
    converted with ``tl.cast(...)`` instead of ``.to()``. In that branch
    ``src_dtype`` and ``use_compute_types`` are not consulted. Every other
    value is delegated to ``self.parent_handler.to_dtype``.

    Returns:
        The ``id`` of the lifted expression.
    """
    operand = str(x)

    if "_item_" in operand:
        # Plain scalar: cast with tl.cast rather than the .to() method.
        cast_expr = expr_from_string(f"tl.cast({operand}, {triton_type(dtype)})")
    else:
        # Everything else keeps the parent handler's default behaviour.
        default_code = _unpack_opsvalue(
            self.parent_handler.to_dtype(x, dtype, src_dtype, use_compute_types)
        )
        cast_expr = expr_from_string(default_code)

    return self.cg.lift(cast_expr).id
973997

974998
def load(self, name: str, index: sympy.Expr) -> str:

helion/_compiler/type_propagation.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,11 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo:
263263
)
264264
),
265265
)
266+
if isinstance(value, zip):
267+
# Handle zip objects by converting to tuple of tuples
268+
# This allows zip to work in list comprehensions
269+
zipped_tuples = tuple(tuple(items) for items in value)
270+
return cls.from_example(zipped_tuples, origin)
266271
raise exc.UnsupportedPythonType(type(value).__name__)
267272

268273
@staticmethod
@@ -2039,8 +2044,63 @@ def _not_on_device_statement(self, node: ast.AST) -> TypeInfo:
20392044
def _not_supported(self, node: ast.AST) -> TypeInfo:
    """Reject ``node`` by raising ``exc.StatementNotSupported`` with its class name."""
    node_kind = type(node).__name__
    raise exc.StatementNotSupported(node_kind)
20412046

2047+
def _evaluate_comprehension(
    self, generator: ast.comprehension, expression: ast.AST
) -> TypeInfo:
    """Propagate types through one comprehension generator.

    The element expression is first typed once against the iterable's
    generic element type (``propagate_iter``). Then, if the iterable can be
    unpacked, it is typed again once per unpacked element so the result
    carries one entry per element.

    Args:
        generator: The comprehension clause (target, iterable, filters).
        expression: The element (result) expression of the comprehension.

    Returns:
        A ``SequenceType`` of the per-element result types. If the
        generator has ``if`` filters, only the first per-element result is
        kept. If a ``NotImplementedError`` is raised while unpacking or
        typing the elements, the result is a single-entry ``SequenceType``
        built from the generic pass.
    """
    iter_type = self.visit(generator.iter)

    def bind_and_visit(bound_type: TypeInfo) -> TypeInfo:
        # Bind the loop target, visit the filters (validation only), then
        # type the element expression.
        self._assign(generator.target, bound_type)
        for condition in generator.ifs:
            self.visit(condition)
        return self.visit(expression)

    # Generic pass: type the expression against the iterable's element type.
    self.push_scope()
    try:
        generic_result = bind_and_visit(iter_type.propagate_iter(self.origin()))
    finally:
        self.pop_scope()

    # Precise pass: one result type per unpacked element, when possible.
    try:
        per_element: list[TypeInfo] = []
        for item_type in iter_type.unpack():
            self.push_scope()
            try:
                # Filters are assumed to pass for every element.
                per_element.append(bind_and_visit(item_type))
            finally:
                self.pop_scope()

        # With filters the exact length is unknown; keep a single entry.
        if generator.ifs and per_element:
            per_element = per_element[:1]

        return SequenceType(self.origin(), per_element)
    except NotImplementedError:
        # Fall back to the generic element type.
        return SequenceType(self.origin(), [generic_result])
2093+
2094+
def visit_ListComp(self, node: ast.ListComp) -> TypeInfo:
    """Propagate types for a list comprehension with exactly one generator.

    Raises:
        exc.StatementNotSupported: If the comprehension does not have
            exactly one generator.
    """
    generators = node.generators
    if len(generators) == 1:
        return self._evaluate_comprehension(generators[0], node.elt)
    raise exc.StatementNotSupported(
        "List comprehensions with multiple generators are not supported"
    )
2102+
20422103
# TODO(jansel): need to implement these
2043-
visit_ListComp: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
20442104
visit_SetComp: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
20452105
visit_GeneratorExp: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
20462106
visit_DictComp: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]

test/test_misc.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1,
190190
load_1 = tl.load(inp_tuple_item_1 + (indices_0[:, None] * inp_tuple_item_1_stride_0 + indices_1[None, :] * inp_tuple_item_1_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
191191
v_0 = load_1.to(tl.float32)
192192
v_1 = load + v_0
193-
v_2 = inp_tuple_item_2.to(tl.float32)
193+
v_2 = tl.cast(inp_tuple_item_2, tl.float32)
194194
v_3 = v_1 * v_2
195195
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
196196

@@ -217,7 +217,7 @@ def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1,
217217
load_1 = tl.load(tl.make_block_ptr(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
218218
v_0 = load_1.to(tl.float32)
219219
v_1 = load + v_0
220-
v_2 = inp_tuple_item_2.to(tl.float32)
220+
v_2 = tl.cast(inp_tuple_item_2, tl.float32)
221221
v_3 = v_1 * v_2
222222
tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
223223

@@ -255,7 +255,7 @@ def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1,
255255
load_1 = inp_tuple_item_1_desc.load([offset_0, offset_1])
256256
v_0 = load_1.to(tl.float32)
257257
v_1 = load + v_0
258-
v_2 = inp_tuple_item_2.to(tl.float32)
258+
v_2 = tl.cast(inp_tuple_item_2, tl.float32)
259259
v_3 = v_1 * v_2
260260
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
261261

0 commit comments

Comments
 (0)