Skip to content

Commit 78b9c7a

Browse files
authored
Add support for global statements in type propagation (#203)
1 parent 9082b57 commit 78b9c7a

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

helion/_compiler/type_propagation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,10 @@ def visit_Assert(self, node: ast.Assert) -> TypeInfo:
18881888
visit_Import: _VisitMethod = generic_statement
18891889
visit_ImportFrom: _VisitMethod = generic_statement
18901890

1891+
def visit_Global(self, node: ast.Global) -> TypeInfo:
1892+
# Global statements don't need child visiting since they only declare names
1893+
return NoType(origin=self.origin())
1894+
18911895
# TODO(jansel): support lambda
18921896
visit_Lambda: _VisitMethod = generic_visit
18931897

test/data/all_ast_nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def all_ast_nodes(x, y):
138138
with contextlib.nullcontext():
139139
e3 = 1
140140

141-
# global global0 # global statements not supported
141+
global global0
142142

143143
out = torch.empty_like(x)
144144
v = 0

test/test_type_propagation.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ def all_ast_nodes(x, y):
492492
t = 0
493493
with contextlib.nullcontext():
494494
e3 = 1
495+
global global0
495496
# Call: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation all_ast_nodes.py:143>)
496497
# Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like')
497498
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
@@ -771,32 +772,32 @@ def fn(x):
771772
output,
772773
"""\
773774
def fn(x):
774-
# Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:761>)
775+
# Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:762>)
775776
# Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like')
776777
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
777778
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
778779
# For: loop_type=GRID
779780
out = torch.empty_like(x)
780-
# Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation test_type_propagation.py:762>)
781+
# Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=<SourceLocation test_type_propagation.py:763>)
781782
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
782783
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
783-
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_type_propagation.py:762>)
784+
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_type_propagation.py:763>)
784785
# Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
785786
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
786787
for tile in hl.tile(x.size()):
787-
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:763>)
788-
# Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:761>)
789-
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:762>)
790-
# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:763>)
791-
# Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=<SourceLocation test_type_propagation.py:763>), key='sin')
792-
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:763>)
788+
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:764>)
789+
# Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation test_type_propagation.py:762>)
790+
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:763>)
791+
# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:764>)
792+
# Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=<SourceLocation test_type_propagation.py:764>), key='sin')
793+
# Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=<SourceLocation test_type_propagation.py:764>)
793794
# Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
794-
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:762>)
795+
# Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=<SourceLocation test_type_propagation.py:763>)
795796
out[tile] = x[tile].sin()
796797
return out
797798
798799
def root_graph_0():
799-
# File: .../test_type_propagation.py:763 in fn, code: out[tile] = x[tile].sin()
800+
# File: .../test_type_propagation.py:764 in fn, code: out[tile] = x[tile].sin()
800801
x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
801802
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
802803
block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')

0 commit comments

Comments
 (0)