Commit e955627

[TK] Add a TestLaunchContext for generating test dispatch IRs (#429)
This patch changes the way we do grid bindings. It is now expected that the IndexingContext will have the workload values, and the Grid will try to build its dims from them on construction. Currently, all symbols are assumed to have a constant value, so the grid is able to map workgroup calculations to constants. Ideally, we would have support for parameterizing the kernel by some symbols, which would then be passed as workload bindings (we have no support for this currently).
1 parent 26d6428 commit e955627
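
In practice the new flow looks like the sketch below. The kernel, symbols, and shapes are illustrative (only TestLaunchContext and the constant-binding behavior come from this patch), and the `tk.lang` helpers (`sym`, `KernelBuffer`, `program_id`) are assumed from the surrounding kernel DSL:

```python
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl

# Illustrative workload symbols; for now, every symbol the grid depends on
# must be bound to a constant at launch time.
M = tkl.sym.M
K = tkl.sym.K


@tk.gen.thread(M)
def row_kernel(a: tkl.KernelBuffer[M, K]):
    row = tkl.program_id(0)
    # ... kernel body elided ...


# TestLaunchContext pushes an IndexingContext with {M: 128, K: 64}, so the
# grid dims resolve to constants and a test dispatch IR can be generated.
with tk.gen.TestLaunchContext({M: 128, K: 64}):
    row_kernel(torch.randn(128, 64))
```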

13 files changed: +249 −262 lines changed


core/shark_turbine/kernel/_support/indexing.py

Lines changed: 22 additions & 17 deletions
@@ -104,13 +104,6 @@ def __new__(
         new_class.__qualname__ = repr(new_class)
         return new_class
 
-    def __class_getitem__(
-        cls, symbolic_shape: Union[SymbolicDimable, tuple[SymbolicShapeable]]
-    ) -> Type["Grid"]:
-        if not isinstance(symbolic_shape, tuple):
-            symbolic_shape = (symbolic_shape,)
-        return cast(Grid, _make_shaped_grid(cls, make_symbolic_shape(symbolic_shape)))
-
     def __repr__(self):
         if self.symbolic_shape:
             return f"Grid[{', '.join(repr(s) for s in self.symbolic_shape)}]"
@@ -122,20 +115,31 @@ class Grid(metaclass=_GridMeta, symbolic_shape=None):
     """Grid with bounding symbolic shape information in the type."""
 
     symbolic_shape: ClassVar[Optional[SymbolicShapeExpr]]
+    # TODO: dims should also allow dynamic dimensions.
     dims: list[int]
     rank: int
 
-    def __init__(self, *dims: int):
-        rank = len(dims)
-        if self.symbolic_shape is not None:
-            if rank != len(self.symbolic_shape):
-                raise ValueError(
-                    f"Cannot create {type(self)}({', '.join(str(i) for i in dims)}): mismatched symbolic rank"
-                )
+    def __init__(self):
+        # Resolve the symbolic shape to concrete values.
+        idxc = IndexingContext.current()
+        if self.symbolic_shape:
+            dims = [idxc.get_static_value(dim) for dim in self.symbolic_shape]
+            if None in dims:
+                raise ValueError(f"NYI: Dynamic dims in Grid")
+            self.dims = cast(list[int], dims)
+        else:
+            self.dims = []
 
-        self.dims = dims
         # Shadow the type rank with the actual, which makes it concrete
         # for the generic case.
-        self.rank = rank
+        self.rank = len(self.dims)
+
+    def __class_getitem__(
+        cls, symbolic_shape: Union[SymbolicDimable, tuple[SymbolicShapeable]]
+    ) -> Type["Grid"]:
+        if not isinstance(symbolic_shape, tuple):
+            symbolic_shape = (symbolic_shape,)
+        return cast(Grid, _make_shaped_grid(cls, make_symbolic_shape(symbolic_shape)))
 
     def __repr__(self):
         return f"{repr(type(self))}({', '.join(str(i) for i in self.dims)})"
@@ -161,6 +165,8 @@ class ShapedGrid(Grid, symbolic_shape=symbolic_shape):
 # KernelBuffer
 ###############################################################################
 
+Dims = list[Union[None, IndexSymbol, int]]
+
 
 class KernelBufferUsage(Enum):
     NONE = 0
@@ -331,7 +337,6 @@ class TemporaryBuffer(KernelBuffer):
 ###############################################################################
 
 ShapedType = Union[Type[KernelBuffer], Type[Grid]]
-Dims = list[Union[None, IndexSymbol, int]]
 
 
 @dataclass(slots=True)
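
To make the new construction path concrete, here is a minimal sketch of shaped-Grid resolution. It uses only pieces from this patch plus an assumed `sym` symbol factory from the kernel lang; TestLaunchContext is defined in tracing.py below:

```python
from shark_turbine.kernel._support.indexing import Grid
from shark_turbine.kernel._support.tracing import TestLaunchContext
from shark_turbine.kernel.lang import sym  # assumed symbol factory

M = sym.M
N = sym.N

# Entering the launch context pushes an IndexingContext binding M = 8, N = 2.
with TestLaunchContext({M: 8, N: 2}):
    grid = Grid[M, N]()  # __init__ resolves dims via IndexingContext.current()
    assert grid.dims == [8, 2]
    assert grid.rank == 2

# Unbound symbols cannot be resolved yet:
with TestLaunchContext():
    try:
        Grid[M]()
    except ValueError as e:
        print(e)  # NYI: Dynamic dims in Grid
```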

core/shark_turbine/kernel/_support/tracing.py

Lines changed: 25 additions & 0 deletions
@@ -23,9 +23,11 @@
     backed_sym_index_type,
     BoundedRelation,
     IndexExpr,
+    IndexSymbol,
     Grid,
     KernelBuffer,
     SymIndex,
+    IndexingContext,
 )
 
 from ..lang.types import (
@@ -388,10 +390,16 @@ def __call__(self, *args, **kwargs):
     def eager_execute(self, args, kwargs):
         ...
 
+    def test_execute(self, args, kwargs):
+        ...
+
 
 class LaunchContext(ABC):
     __tk_context_idname__ = "ExecutionContext"
 
+    def __init__(self, constant_bindings: Dict[IndexSymbol, int] = {}):
+        self.constant_bindings = constant_bindings
+
     @staticmethod
     def current() -> "LaunchContext":
         try:
@@ -404,9 +412,21 @@ def current() -> "LaunchContext":
             return DebugLaunchContext()
 
     def __enter__(self) -> "LaunchContext":
+        # Push an indexing context with the constant bindings for this launch
+        # context in it.
+        # TODO: Is creating an IndexingContext as part of LaunchContext the
+        # correct layering?
+        idxc = IndexingContext()
+        context.push(IndexingContext, idxc)
+        for s, val in self.constant_bindings.items():
+            idxc.bind_constant(s, val)
         return context.push(LaunchContext, self)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        # Pop the indexing context created as part of this launch.
+        # TODO: Is creating an IndexingContext as part of LaunchContext the
+        # correct layering?
+        context.pop(IndexingContext, IndexingContext.current())
         context.pop(LaunchContext, self)
 
     @abstractmethod
@@ -419,6 +439,11 @@ def launch(self, launchable: Launchable, args, kwargs):
         return launchable.eager_execute(args, kwargs)
 
 
+class TestLaunchContext(LaunchContext):
+    def launch(self, launchable: Launchable, args, kwargs):
+        return launchable.test_execute(args, kwargs)
+
+
 ###############################################################################
 # Helpers
 ###############################################################################
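
The launch/indexing layering above can be checked directly: constant bindings are visible through `IndexingContext.current()` between `__enter__` and `__exit__` (a sketch; `sym` is again the assumed symbol factory):

```python
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel._support.tracing import TestLaunchContext
from shark_turbine.kernel.lang import sym  # assumed symbol factory

M = sym.M

with TestLaunchContext({M: 64}):
    # __enter__ pushed a fresh IndexingContext and bound M = 64 in it.
    assert IndexingContext.current().get_static_value(M) == 64
# __exit__ popped the IndexingContext along with the LaunchContext.
```

Note that `constant_bindings: Dict[...] = {}` is a shared mutable default; it is safe only as long as no caller mutates the dict in place.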

core/shark_turbine/kernel/compiler/dispatch_codegen.py

Lines changed: 32 additions & 21 deletions
@@ -5,7 +5,7 @@
 embedding and generating the calls/dispatches.
 """
 
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Type
 
 from .._support.indexing import (
     IndexingContext,
@@ -44,6 +44,8 @@
     KernelSignature,
 )
 
+from .._support.indexing import Grid
+
 
 class StreamExecutable:
     """Encapsulates a 'stream' compilable executable which can be dispatched to.
@@ -93,6 +95,7 @@ def define_entrypoint(
         self,
         name: str,
         sig: KernelSignature,
+        grid: Grid,
     ) -> "DispatchEntrypoint":
         """Defines a dispatch function with a signature like:
@@ -105,7 +108,7 @@
         Also adds an export with workgroup function like:
 
         ```
-        stream.executable.export public @name(%workload0 : index, %workload1 : index) -> (index, [[grid_arity...]]) {
+        stream.executable.export private @name(%workload0 : index, %workload1 : index) -> (index, [[grid_arity...]]) {
 
         }
         ```
@@ -115,28 +118,32 @@
         kb_input_bindings = sig.kernel_buffer_input_bindings
         kb_temp_bindings = sig.kernel_buffer_temporary_bindings
         kb_output_bindings = sig.kernel_buffer_output_bindings
-        # TODO: The way we are doing grid bindings is wrong. The Grid type should be paramerized
-        # with special grid axis symbols which are algebraically related to concrete shape dim
-        # symbols. For now, we are just treating the grid symbol as the input and output to the
-        # workload function, when in reality, the workload needs to derive from its leaf inputs.
-        grid_axis_bindings = sig.grid_bindings
+        # TODO: The way we are doing grid bindings is wrong. The Grid type
+        # should be parameterized with special grid axis symbols which are
+        # algebraically related to concrete shape dim symbols. For now, we are
+        # just assuming that the grid dims can be resolved to constants, when
+        # in reality, we should pass the workload and parameterize the grid
+        # dims on the workloads.
+        workload_axis_bindings = []
 
         # Input bindings are always user specified.
         # Grid/workgroup bindings are in the inputs section but are implied.
         # Temp bindings are a special kind of output bindings.
         # Output bindings are the real outputs.
         linear_bindings = (
             kb_input_bindings
-            + grid_axis_bindings
+            + workload_axis_bindings
             + kb_temp_bindings
             + kb_output_bindings
         )
 
-        # TODO: This is sloppy. This assert will hit on some user errors for unsupported
-        # type combinations and is just a last resort right now.
-        assert len(linear_bindings) == len(
-            sig.bindings
-        ), f"Not all bindings converted: {linear_bindings} vs {sig.bindings}"
+        # TODO: This is sloppy. This assert will hit on some user errors for
+        # unsupported type combinations and is just a last resort right now.
+        # TODO: This is currently disabled because the grid_bindings don't match
+        # workload bindings.
+        # assert len(linear_bindings) == len(
+        #     sig.bindings
+        # ), f"Not all bindings converted: {linear_bindings} vs {sig.bindings}"
 
         with self._loc:
             binding_type = IrType.parse("!stream.binding")
@@ -161,17 +168,22 @@ def abi_type(binding: BindingDesc):
         with InsertionPoint.at_block_begin(self._exe_block):
             export_op = stream_d.ExecutableExportOp(name, name)
             export_block = export_op.workgroup_count.blocks.append(
-                *([b.as_mlir_type() for b in grid_axis_bindings])
+                *([b.as_mlir_type() for b in workload_axis_bindings])
             )
 
-            # TODO: Reify actual workload calculation.
             workgroup_builder = WorkgroupBuilder(
                 export_block, lambda vs: stream_d.ReturnOp(vs)
            )
-            workgroup_values = list(workgroup_builder.workload)
-            while len(workgroup_values) < 3:
-                with InsertionPoint(workgroup_builder.entry_block):
-                    result_type = IndexType.get()
+
+            # TODO: Support passing workload to the dispatch function.
+            with InsertionPoint(workgroup_builder.entry_block):
+                result_type = IndexType.get()
+                workgroup_values = [
+                    arith_d.constant(result_type, IntegerAttr.get(result_type, dim))
+                    for dim in grid.dims
+                ]
+
+                while len(workgroup_values) < 3:
                     workgroup_values.append(
                         arith_d.constant(result_type, IntegerAttr.get(result_type, 1))
                     )
@@ -220,8 +232,7 @@ def __init__(
     def resolve(self, binding: BindingDesc) -> Value:
         ref_type, ref_value = binding.reference
         if ref_type == "grid":
-            # TODO: Switch to stream op when #15889 is landed.
-            return flow_d.dispatch_workgroup_id(
+            return stream_d.dispatch_workgroup_id(
                 IntegerAttr.get(IndexType.get(), ref_value)
            )
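
The workgroup-count logic is easiest to see in isolation. Here is a pure-Python mirror of what the builder now emits as `arith.constant` ops in the export's workgroup-count region:

```python
def workgroup_counts(grid_dims: list[int]) -> list[int]:
    # One constant per resolved grid dim, padded with 1s up to the three
    # workgroup dimensions the stream dialect expects.
    counts = list(grid_dims)
    while len(counts) < 3:
        counts.append(1)
    return counts


assert workgroup_counts([4, 2]) == [4, 2, 1]
assert workgroup_counts([]) == [1, 1, 1]
```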

core/shark_turbine/kernel/compiler/host_codegen.py (new file)

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from .kernel_codegen import KernelSignature
+from .dispatch_codegen import StreamExecutable
+
+from .builder import (
+    ModuleBuilder,
+)
+
+from .ir import (
+    Block,
+    FunctionType,
+    InsertionPoint,
+    IrType,
+    Location,
+    ArrayAttr,
+    SymbolRefAttr,
+    MemRefType,
+    RankedTensorType,
+    flow_d,
+    func_d,
+)
+
+
+def memref_to_tensor(memrefs: list[IrType]):
+    tensors = []
+    for m in memrefs:
+        assert isinstance(m, MemRefType)
+        t = RankedTensorType.get(m.shape, m.element_type)
+        tensors.append(t)
+    return tensors
+
+
+def isolated_test_call(
+    mb: ModuleBuilder, exe: StreamExecutable, sig: KernelSignature, entrypoint: str
+):
+    with InsertionPoint(mb.body_block), Location.unknown():
+        input_types = [b.as_mlir_type() for b in sig.kernel_buffer_input_bindings]
+        input_tensors = memref_to_tensor(input_types)
+        output_types = [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings]
+        output_tensors = memref_to_tensor(output_types)
+
+        ftype = FunctionType.get(input_tensors, output_tensors)
+        func_op = func_d.FuncOp("isolated_benchmark", ftype)
+        arg_locs = [
+            (Location.name(b.name) if b.name is not None else Location.unknown())
+            for b in sig.kernel_buffer_input_bindings
+        ]
+        entry_block = func_op.add_entry_block(arg_locs)
+        with InsertionPoint(entry_block):
+            assert isinstance(entry_block, Block)
+            # Create a flow.dispatch op to the kernel
+            dispatch = SymbolRefAttr.get([exe.sym_name.value, entrypoint])
+            entrypoints = ArrayAttr.get([dispatch])
+
+            out = flow_d.DispatchOp(
+                output_tensors, [], entrypoints, entry_block.arguments, [], []
+            )
+
+            func_d.ReturnOp(out)
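
The only subtle step in the new host-side generator is the memref→tensor bridge: dispatch bindings are typed as memrefs, while the host function traffics in tensors. A standalone illustration of the same conversion, using the upstream MLIR bindings that `.ir` re-exports (assuming `iree.compiler.ir` is available):

```python
from iree.compiler.ir import Context, F32Type, MemRefType, RankedTensorType

with Context():
    f32 = F32Type.get()
    m = MemRefType.get([16, 16], f32)
    # Same per-binding conversion that memref_to_tensor performs:
    t = RankedTensorType.get(m.shape, m.element_type)
    print(m)  # memref<16x16xf32>
    print(t)  # tensor<16x16xf32>
```

The emitted `isolated_benchmark` function then simply invokes the entrypoint via `flow.dispatch` on its tensor arguments and returns the output tensors.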

core/shark_turbine/kernel/compiler/ir.py

Lines changed: 3 additions & 0 deletions
@@ -2,8 +2,11 @@
     AffineConstantExpr,
     AffineExpr,
     AffineMap,
+    FlatSymbolRefAttr,
+    SymbolRefAttr,
     AffineMapAttr,
     Attribute,
+    RankedTensorType,
     ArrayAttr,
     Block,
     Context,

core/shark_turbine/kernel/compiler/kernel_codegen.py

Lines changed: 0 additions & 1 deletion
@@ -118,7 +118,6 @@ def grid_bindings(self) -> list[BindingDesc]:
     @property
     def kernel_buffer_input_bindings(self) -> list[BindingDesc]:
         """Gets all kernel buffer bindings with input usage."""
-        print("ALL=", self.bindings)
         return [
             b
             for b in self.bindings
core/shark_turbine/kernel/gen/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 from .thread import *
+
+from .._support.tracing import TestLaunchContext
