Skip to content

Commit d8d0372

Browse files
authored
Fix issue with ConfigSpec mutation in codegen (#195)
Fixes #185
1 parent 169b0d7 commit d8d0372

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

helion/language/tile_ops.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _(tile: torch.SymInt) -> torch.Tensor:
4444

4545
@_decorators.codegen(tile_index)
4646
def _(state: CodegenState) -> ast.AST:
47-
index = _get_tile_index(state)
47+
index = _disable_flatten_get_tile(state.proxy_arg(0))
4848
return expr_from_string(state.codegen.index_var(index))
4949

5050

@@ -59,25 +59,24 @@ def tile_begin(tile: Tile) -> int:
5959

6060
@_decorators.register_fake(tile_begin)
6161
def _(tile: torch.SymInt) -> torch.SymInt:
62+
_disable_flatten_get_tile(tile) # update config spec if needed
6263
return CompileEnvironment.current().create_unbacked_symint()
6364

6465

65-
def _get_tile_index(state: CodegenState, disable_flatten: bool = True) -> int:
66+
def _disable_flatten_get_tile(tile: object) -> int:
6667
"""Helper to get the block id of a tile and disable loop flattening for that block."""
67-
tile = state.proxy_arg(0)
68-
assert isinstance(tile, torch.SymInt)
68+
assert isinstance(tile, torch.SymInt), (type(type), tile)
6969
env = CompileEnvironment.current()
7070
index = env.get_block_id(tile)
7171
assert index is not None
72-
if disable_flatten:
73-
# The functions in this file can't be used in flattened loops.
74-
env.config_spec.flatten_loops.disable_block_id(index)
72+
# The functions in this file can't be used in flattened loops.
73+
env.config_spec.flatten_loops.disable_block_id(index)
7574
return index
7675

7776

7877
@_decorators.codegen(tile_begin)
7978
def _(state: CodegenState) -> ast.AST:
80-
index = _get_tile_index(state)
79+
index = _disable_flatten_get_tile(state.proxy_arg(0))
8180
return expr_from_string(state.codegen.offset_var(index))
8281

8382

@@ -94,12 +93,13 @@ def tile_end(tile: Tile) -> int:
9493

9594
@_decorators.register_fake(tile_end)
9695
def _(tile: torch.SymInt) -> torch.SymInt:
96+
_disable_flatten_get_tile(tile) # update config spec if needed
9797
return CompileEnvironment.current().create_unbacked_symint()
9898

9999

100100
@_decorators.codegen(tile_end)
101101
def _(state: CodegenState) -> ast.AST:
102-
index = _get_tile_index(state)
102+
index = _disable_flatten_get_tile(state.proxy_arg(0))
103103
offset_var = state.codegen.offset_var(index)
104104
block_size_var = state.device_function.block_size_var(index)
105105
if block_size_var is None:

test/test_misc.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,20 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
252252
return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
253253
)
254254

255+
def test_config_flatten_issue(self):
256+
@helion.kernel(use_default_config=True)
257+
def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
258+
out = torch.zeros_like(x, dtype=torch.int32)
259+
for tile_m, tile_n in hl.tile(x.size()):
260+
out[tile_m.begin, tile_n.begin] = 1
261+
return out
262+
263+
x = torch.randn(64, 64, device="cuda")
264+
config = helion.Config(block_sizes=[16, 16])
265+
test_tile_id_atomic_add.bind((x,)).to_triton_code(config)
266+
result = test_tile_id_atomic_add.bind((x,)).compile_config(config)(x)
267+
self.assertEqual(result.sum().item(), 16)
268+
255269

256270
if __name__ == "__main__":
257271
unittest.main()

0 commit comments

Comments
 (0)