@@ -44,7 +44,7 @@ def _(tile: torch.SymInt) -> torch.Tensor:
44
44
45
45
@_decorators.codegen(tile_index)
def _(state: CodegenState) -> ast.AST:
    """Generate the expression for the index variable of the tile argument.

    The tile is read from proxy argument 0; resolving it to a block id
    also disables loop flattening for that block.
    """
    block_id = _disable_flatten_get_tile(state.proxy_arg(0))
    index_var = state.codegen.index_var(block_id)
    return expr_from_string(index_var)
49
49
50
50
@@ -59,25 +59,24 @@ def tile_begin(tile: Tile) -> int:
59
59
60
60
@_decorators.register_fake(tile_begin)
def _(tile: torch.SymInt) -> torch.SymInt:
    """Fake implementation of ``tile_begin``: return a new unbacked symint."""
    # Called only for its side effect: update config spec if needed.
    _disable_flatten_get_tile(tile)
    env = CompileEnvironment.current()
    return env.create_unbacked_symint()
63
64
64
65
65
- def _get_tile_index ( state : CodegenState , disable_flatten : bool = True ) -> int :
66
+ def _disable_flatten_get_tile ( tile : object ) -> int :
66
67
"""Helper to extract tile index from state."""
67
- tile = state .proxy_arg (0 )
68
- assert isinstance (tile , torch .SymInt )
68
+ assert isinstance (tile , torch .SymInt ), (type (type ), tile )
69
69
env = CompileEnvironment .current ()
70
70
index = env .get_block_id (tile )
71
71
assert index is not None
72
- if disable_flatten :
73
- # The functions in this file can't be used in flattened loops.
74
- env .config_spec .flatten_loops .disable_block_id (index )
72
+ # The functions in this file can't be used in flattened loops.
73
+ env .config_spec .flatten_loops .disable_block_id (index )
75
74
return index
76
75
77
76
78
77
@_decorators.codegen(tile_begin)
def _(state: CodegenState) -> ast.AST:
    """Generate the expression for the offset variable of the tile argument.

    The tile is read from proxy argument 0; resolving it to a block id
    also disables loop flattening for that block.
    """
    block_id = _disable_flatten_get_tile(state.proxy_arg(0))
    offset_var = state.codegen.offset_var(block_id)
    return expr_from_string(offset_var)
82
81
83
82
@@ -94,12 +93,13 @@ def tile_end(tile: Tile) -> int:
94
93
95
94
@_decorators.register_fake(tile_end)
def _(tile: torch.SymInt) -> torch.SymInt:
    """Fake implementation of ``tile_end``: return a new unbacked symint."""
    # Called only for its side effect: update config spec if needed.
    _disable_flatten_get_tile(tile)
    env = CompileEnvironment.current()
    return env.create_unbacked_symint()
98
98
99
99
100
100
@_decorators .codegen (tile_end )
101
101
def _ (state : CodegenState ) -> ast .AST :
102
- index = _get_tile_index (state )
102
+ index = _disable_flatten_get_tile (state . proxy_arg ( 0 ) )
103
103
offset_var = state .codegen .offset_var (index )
104
104
block_size_var = state .device_function .block_size_var (index )
105
105
if block_size_var is None :
0 commit comments