|
20 | 20 | import torch.fx as fx |
21 | 21 |
|
22 | 22 | from .indexing import ( |
23 | | - BoundedSymbolicValue, |
| 23 | + backed_sym_index_type, |
| 24 | + BoundedRelation, |
| 25 | + IndexExpr, |
24 | 26 | Grid, |
25 | 27 | KernelBuffer, |
26 | | - sym_0, |
| 28 | + SymIndex, |
27 | 29 | ) |
28 | 30 |
|
29 | 31 | from ..lang.types import ( |
@@ -98,10 +100,17 @@ class KernelTracer(SubgraphTracer): |
98 | 100 | # Register our custom proxies. |
99 | 101 | def proxy(self, node: fx.Node) -> fx.Proxy: |
100 | 102 | t = node.type |
101 | | - if t is not None and issubclass(t, KernelBuffer): |
102 | | - return KernelBufferProxy(node, self, t) |
| 103 | + if t is not None: |
| 104 | + if issubclass(t, KernelBuffer): |
| 105 | + return KernelBufferProxy(node, self, t) |
103 | 106 | return super().proxy(node) |
104 | 107 |
|
| 108 | + def create_arg(self, a): |
| 109 | + # Let IndexExpr persist as arguments. |
| 110 | + if isinstance(a, IndexExpr): |
| 111 | + return a |
| 112 | + return super().create_arg(a) |
| 113 | + |
105 | 114 |
|
106 | 115 | class CapturedTrace: |
107 | 116 | def __init__(self, region_graph: RegionGraph, root_graph: str): |
@@ -163,23 +172,28 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]): |
163 | 172 | super().__init__(eager=False) |
164 | 173 | self.region_graph = region_graph |
165 | 174 | self.grid_type = grid_type |
| 175 | + self.current_thread_types = [ |
| 176 | + backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False)) |
| 177 | + for n in grid_type.symbolic_shape |
| 178 | + ] |
166 | 179 |
|
167 | 180 | ### ======================================================================== |
168 | 181 | ### Core Operations |
169 | 182 | ### ======================================================================== |
170 | 183 |
|
171 | 184 | def handle_thread_program_id(self, op, axis: int) -> Index: |
172 | | - grid_shape = self.grid_type.symbolic_shape |
173 | | - if axis < 0 or axis >= len(grid_shape): |
| 185 | + grid_types = self.current_thread_types |
| 186 | + if axis < 0 or axis >= len(grid_types): |
174 | 187 | raise IndexError( |
175 | | - f"Illegal index into grid of rank {len(grid_shape)}: {axis}" |
| 188 | + f"Illegal index into grid of rank {len(grid_types)}: {axis}" |
176 | 189 | ) |
| 190 | + |
177 | 191 | proxy = self.region_graph.create_proxy( |
178 | 192 | "call_function", |
179 | 193 | op, |
180 | 194 | args=(axis,), |
181 | 195 | kwargs={}, |
182 | | - type_expr=BoundedSymbolicValue.bound(sym_0, grid_shape[axis]), |
| 196 | + type_expr=grid_types[axis], |
183 | 197 | ) |
184 | 198 | return proxy |
185 | 199 |
|
|
0 commit comments