Commit 51bf969

[Wave] Enable async kernel execution with iree runtime and fully switch to Launchable (#11)
Drop our `invoke_vmfb` code entirely, use `turbine.runtime.Launchable`, and update dispatch codegen to support async launch.

Before:

```mlir
func.func @isolated_benchmark(%arg0: tensor<8x?x128x6xf32>, %arg1: tensor<8x?x6xf32>, %arg2: tensor<?xi32>, %arg3: tensor<?x6x128xf16>, %arg4: index) -> tensor<?x6x128xf16> {
  %0 = flow.dispatch @phase_1::@phase_1[%arg4](%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<8x?x128x6xf32>{%arg4}, tensor<8x?x6xf32>{%arg4}, tensor<?xi32>{%arg4}, tensor<?x6x128xf16>{%arg4}, index) -> %arg3{%arg4}
  return %0 : tensor<?x6x128xf16>
}
```

After:

```mlir
func.func @isolated_benchmark$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: index, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view {
  %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<8x?x128x6xf32>{%arg4}
  %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<8x?x6xf32>{%arg4}
  %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<?xi32>{%arg4}
  %3 = hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<?x6x128xf16>{%arg4}
  %4 = flow.dispatch @phase_1::@phase_1[%arg4](%0, %1, %2, %3, %arg4) : (tensor<8x?x128x6xf32>{%arg4}, tensor<8x?x6xf32>{%arg4}, tensor<?xi32>{%arg4}, tensor<?x6x128xf16>{%arg4}, index) -> %3{%arg4}
  %5 = hal.tensor.barrier join(%4 : tensor<?x6x128xf16>) => %arg6 : !hal.fence
  %6 = hal.tensor.export %5 : tensor<?x6x128xf16>{%arg4} -> !hal.buffer_view
  return %6 : !hal.buffer_view
}
```

Also, add some Python profiling code to the `WaveKernel` launch function. Launch overhead for the IREE runtime is still high (around 5x that of `wave_runtime`), but the ability to run kernels asynchronously is an overall improvement.

---------

Signed-off-by: Ivan Butygin <[email protected]>
1 parent 7050dc6 commit 51bf969
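For context, here is a rough sketch of what launching a compiled kernel through `Launchable` looks like. This is an editorial illustration, not code from this commit: `Launchable.jit_compile` and the call convention are taken from iree-turbine's public API, and the shapes are made up.

```python
# Hypothetical usage sketch of turbine.runtime.Launchable (assumed API).
import torch
from iree.turbine.runtime import Launchable

MLIR_ASM = "..."  # e.g. a module exposing isolated_benchmark$async

# jit_compile compiles lazily and caches the resulting vmfb per device.
launch = Launchable.jit_compile(MLIR_ASM)

a = torch.randn(8, 16, 128, 6, device="cuda")
b = torch.randn(8, 16, 6, device="cuda")

# The call enqueues work on the tensors' device without blocking until
# the kernel finishes, which is what makes async execution possible.
result = launch(a, b)
```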

File tree

10 files changed, +207 -172 lines


lit_tests/kernel/wave/codegen.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -2167,6 +2167,8 @@ def scalar_codegen_f32(
     scalar_codegen_f32 = wave_compile(options, scalar_codegen_f32)
     print(scalar_codegen_f32.asm)
 
+    # CHECK-LABEL: test_scalar_codegen_f32
+
     # Passed scalars' dtype
     # CHECK: func.func @scalar_codegen_f32(
     # CHECK-SAME: %arg2: f32, %arg3: f32)
@@ -2177,8 +2179,11 @@ def scalar_codegen_f32(
     # CHECK: arith.addf
 
     # Final dispatch args dtype
+    # CHECK: func.func @isolated_benchmark$async(%[[ARG0:.*]]: !hal.buffer_view, %[[ARG1:.*]]: !hal.buffer_view, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32
+    # CHECK: %[[V0:.*]] = hal.tensor.import wait(%{{.*}}) => %[[ARG0]]
+    # CHECK: %[[V1:.*]] = hal.tensor.import wait(%{{.*}}) => %[[ARG1]]
     # CHECK: flow.dispatch @scalar_codegen_f32::@scalar_codegen_f32(
-    # CHECK-SAME: %arg0, %arg1, %arg2, %arg3)
+    # CHECK-SAME: %[[V0]], %[[V1]], %[[ARG2]], %[[ARG3]])
 
 
 @run_test
@@ -2220,6 +2225,8 @@ def scalar_codegen_i32(
     scalar_codegen_i32 = wave_compile(options, scalar_codegen_i32)
     print(scalar_codegen_i32.asm)
 
+    # CHECK-LABEL: test_scalar_codegen_i32
+
     # Passed scalars' dtype: i32
     # CHECK: func.func @scalar_codegen_i32(
     # CHECK-SAME: %arg2: i32, %arg3: i32)
@@ -2230,8 +2237,11 @@ def scalar_codegen_i32(
     # CHECK: arith.addi
 
     # Final dispatch args dtype
+    # CHECK: func.func @isolated_benchmark$async(%[[ARG0:.*]]: !hal.buffer_view, %[[ARG1:.*]]: !hal.buffer_view, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32
+    # CHECK: %[[V0:.*]] = hal.tensor.import wait(%{{.*}}) => %[[ARG0]]
+    # CHECK: %[[V1:.*]] = hal.tensor.import wait(%{{.*}}) => %[[ARG1]]
     # CHECK: flow.dispatch @scalar_codegen_i32::@scalar_codegen_i32(
-    # CHECK-SAME: %arg0, %arg1, %arg2, %arg3)
+    # CHECK-SAME: %[[V0]], %[[V1]], %[[ARG2]], %[[ARG3]])
 
 
 # This kernel copies of data from a into b if tid.x < threshold.
```
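The checks above move from hard-coded `%arg0`/`%arg1` operands to FileCheck captures because the dispatch now consumes the results of `hal.tensor.import` rather than raw function arguments, and those SSA names are not stable. As a quick reminder of the capture syntax (standard LLVM FileCheck behavior, not specific to this commit):

```python
# In these lit tests the CHECK directives are plain Python comments.
# `%[[V0:.*]]` binds whatever SSA name appears there to the variable V0;
# a later `%[[V0]]` then requires that same name, so the test tracks
# values through the IR instead of assuming fixed argument numbering.
# (The @kernel name below is a generic placeholder.)
#
# CHECK: %[[V0:.*]] = hal.tensor.import wait(%{{.*}}) => %[[ARG0]]
# CHECK: flow.dispatch @kernel::@kernel(
# CHECK-SAME: %[[V0]],
```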

lit_tests/kernel/wave/location.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -71,7 +71,7 @@ def add_loc_local_scope(
     # CHECK: vector.load {{.*}} loc("{{.*}}location.py":{{[0-9]+}}
     # CHECK: arith.addf {{.*}} loc("{{.*}}location.py":{{[0-9]+}}
     #
-    # CHECK: @isolated_benchmark(%{{.*}} loc("a"("{{.*}}location.py":{{[0-9]+}}{{.*}} loc("b"("{{.*}}location.py":{{[0-9]+}}
+    # CHECK: @isolated_benchmark$async(%{{.*}} loc("a"("{{.*}}location.py":{{[0-9]+}}{{.*}} loc("b"("{{.*}}location.py":{{[0-9]+}}
 
 
 @run_test
@@ -98,7 +98,7 @@ def add_loc_global_scope(
     # CHECK-LABEL: @add_loc_global_scope
     # CHECK: vector.load {{.*}} loc(#[[loc_load:.+]])
     # CHECK: arith.addf {{.*}} loc(#[[loc_addf:.+]])
-    # CHECK: @isolated_benchmark(%{{.*}} loc("a"(#[[loc_arg]])), %{{.*}} loc("b"(#[[loc_arg]])))
+    # CHECK: @isolated_benchmark$async(%{{.*}} loc("a"(#[[loc_arg]])), %{{.*}} loc("b"(#[[loc_arg]])), %{{.*}} loc(unknown), %{{.*}} loc(unknown))
     # CHECK-DAG: #[[loc_load]] = loc("{{.*}}location.py":{{[0-9]+}}
     # CHECK-DAG: #[[loc_addf]] = loc("{{.*}}location.py":{{[0-9]+}}
 
```
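The two trailing `%{{.*}} loc(unknown)` matches correspond to the wait/signal `!hal.fence` arguments that the async ABI appends; they have no Python-level source, so host codegen assigns them unknown locations (see the `host_codegen.py` hunk below). A minimal sketch of that location bookkeeping with the MLIR Python bindings, as an illustration rather than the commit's exact code:

```python
from iree.compiler import ir

with ir.Context():
    # Buffer arguments carry named locations such as loc("a"(...)), while
    # the two appended fence arguments get Location.unknown(), which
    # prints as `loc(unknown)`, exactly what the new CHECK line matches.
    arg_locs = [ir.Location.name("a"), ir.Location.name("b")]
    arg_locs += [ir.Location.unknown()] * 2
    print(arg_locs[-1])  # loc(unknown)
```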

lit_tests/kernel/wave/sharktank_integration.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -165,6 +165,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
             func_name=wave_kernel_name,
             compile_to_mlir=True,
             canonicalize=False,
+            iree_launch_async=False,
         )
         options = set_default_run_config(options)
         with Context() as ctx:
```
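The integration pins `iree_launch_async=False`, presumably because this path only emits MLIR for embedding into sharktank's kernel builder and never launches through the IREE runtime wrapper. For a regular `wave_compile` run, the flag would instead be enabled on the options. A hedged sketch follows; the import paths and the `subs` field are assumptions based on the surrounding code, and only `iree_launch_async` itself comes from this commit:

```python
# Illustrative sketch; module paths are assumed from this repo's layout.
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config

options = WaveCompileOptions(
    subs=hyperparams,        # kernel hyperparameters (assumed field)
    iree_launch_async=True,  # emit isolated_benchmark$async + fence args
)
options = set_default_run_config(options)
compiled = wave_compile(options, my_kernel)  # my_kernel: a wave kernel
```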

wave_lang/kernel/compiler/host_codegen.py

Lines changed: 67 additions & 6 deletions
```diff
@@ -19,6 +19,8 @@
     arith_d,
     flow_d,
     func_d,
+    hal_d,
+    tensor_d,
 )
 
 from .._support.indexing import IndexSymbol
@@ -31,7 +33,10 @@
 from .kernel_codegen import BindingDesc, KernelSignature
 
 
-def memref_to_tensor(memrefs: list[IrType]):
+def memref_to_tensor(memrefs: list[IrType], use_views: bool = False):
+    if use_views:
+        view_type = IrType.parse("!hal.buffer_view")
+
     tensors = []
     for m in memrefs:
         # append scalars as-it-is to tensors list
@@ -41,7 +46,7 @@ def memref_to_tensor(memrefs: list[IrType]):
             tensors.append(m)
             continue
         assert isinstance(m, MemRefType)
-        t = RankedTensorType.get(m.shape, m.element_type)
+        t = view_type if use_views else RankedTensorType.get(m.shape, m.element_type)
         tensors.append(t)
     return tensors
 
@@ -79,12 +84,14 @@ def isolated_test_call(
     dynamic_symbols: list[IndexSymbol] = [],
     *,
     location_capture_config: Optional[LocationCaptureConfig] = None,
+    async_dispatch: bool = False,
 ):
     with InsertionPoint(mb.body_block), Location.unknown():
         input_types = [b.as_mlir_type() for b in sig.kernel_buffer_bindings] + [
             b.as_mlir_type() for b in sig.scalar_bindings
         ]
-        input_tensors = memref_to_tensor(input_types)
+
+        input_tensors = memref_to_tensor(input_types, use_views=async_dispatch)
         argument_dims = get_dynamic_dims(sig.kernel_buffer_bindings, dynamic_symbols)
         # Adding unique dynamic dims as inputs.
         input_tensors += [IndexType.get() for _ in list(dict.fromkeys(argument_dims))]
@@ -93,8 +100,13 @@
             IndexType.get() for _ in set(dynamic_symbols).difference(argument_dims)
         ]
 
+        if async_dispatch:
+            fence_type = IrType.parse("!hal.fence")
+            input_tensors += [fence_type] * 2
+            func_name = func_name + "$async"
+
         output_types = [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings]
-        output_tensors = memref_to_tensor(output_types)
+        output_tensors = memref_to_tensor(output_types, use_views=async_dispatch)
         result_dims = get_dynamic_dims(
             sig.kernel_buffer_output_bindings, dynamic_symbols
         )
@@ -110,13 +122,39 @@
             + scalar_bindings
             + sig.dynamic_dim_bindings
         ]
+        if async_dispatch:
+            arg_locs += [Location.unknown()] * 2
+
         entry_block = func_op.add_entry_block(arg_locs)
         scalars_offset = len(sig.kernel_buffer_bindings)
         scalars_count = len(scalar_bindings)
         dynamic_offset = scalars_offset + scalars_count
 
         with InsertionPoint(entry_block):
             arguments = entry_block.arguments
+            if async_dispatch:
+                in_fence = arguments[-2]
+                out_fence = arguments[-1]
+                arguments = list(arguments[:-2])
+
+                for i, b in enumerate(sig.kernel_buffer_bindings):
+                    shape = b.kernel_buffer_type.symbolic_shape
+
+                    arg = arguments[i]
+                    arg_type = memref_to_tensor([b.as_mlir_type()])[0]
+                    target_dims = [
+                        hal_d.buffer_view_dim(arg, d)
+                        for d in range(len(shape))
+                        if arg_type.is_dynamic_dim(d)
+                    ]
+                    arguments[i] = hal_d.tensor_import(
+                        arg_type,
+                        arg,
+                        wait_fence=in_fence,
+                        target_encoding=arg_type,
+                        target_dims=target_dims,
+                    )
+
             scalars_args = [
                 to_index(v)
                 for v, b in zip(
@@ -142,13 +180,36 @@
             )
 
             out = flow_d.DispatchOp(
-                output_tensors,
+                memref_to_tensor(output_types),  # output_tensors,
                 [dynamic_argument_map[dim] for dim in dynamic_symbols] + scalars_args,
                 entrypoints,
-                entry_block.arguments,
+                arguments,
                 [dynamic_argument_map[dim] for dim in argument_dims],
                 [dynamic_argument_map[dim] for dim in result_dims],
                 tied_operands=tied_operands,
             )
 
+            if async_dispatch:
+                out = list(out.results)
+                out_types = memref_to_tensor(
+                    [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings]
+                )
+                barrier = hal_d.tensor_barrier(out_types, out, signal_fence=out_fence)
+                if len(out_types) == 1:
+                    barrier = [barrier]
+
+                view_type = IrType.parse("!hal.buffer_view")
+                for i, b in enumerate(sig.kernel_buffer_output_bindings):
+                    shape = b.kernel_buffer_type.symbolic_shape
+
+                    out_type = out_types[i]
+                    source_dims = [
+                        tensor_d.dim(out[i], arith_d.constant(IndexType.get(), d))
+                        for d in range(len(shape))
+                        if out_type.is_dynamic_dim(d)
+                    ]
+                    out[i] = hal_d.tensor_export(
+                        view_type, barrier[i], out_type, source_dims=source_dims
+                    )
+
             func_d.ReturnOp(out)
```
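Taken together, the generated wrapper imports each buffer view once the wait fence is signalled, runs the dispatch, joins the results on a barrier that signals the output fence, and exports them back to buffer views. A rough host-side sketch of driving such an entry point with fences from IREE's Python runtime; the actual launch in this commit goes through `Launchable`, and argument marshalling is elided:

```python
import iree.runtime as rt

device = rt.get_device("local-task")
sem = device.create_semaphore(0)

# Inputs are ready at timepoint 0; the device signals timepoint 1 once
# all results of the dispatch are complete.
wait_fence = rt.HalFence.create_at(sem, 0)
signal_fence = rt.HalFence.create_at(sem, 1)

# isolated_benchmark$async(buffers..., dims..., wait_fence, signal_fence)
# returns buffer views immediately; their contents are only valid once
# signal_fence is reached, so more work can be queued in the meantime.
# ... invoke the $async entry point here ...
signal_fence.wait()
```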

wave_lang/kernel/wave/cache.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -199,6 +199,7 @@ def get_hash(
         options.optimization_level,
         options.denorm_fp_math_f32,
         options.waves_per_eu,
+        options.iree_launch_async,
         options.use_buffer_load_ops,
         options.use_buffer_store_ops,
         options.use_stride_cache_swizzle,
```
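Since the launch mode changes the generated host function (sync `isolated_benchmark` vs. `isolated_benchmark$async` with fences), the flag has to feed the cache hash, or a kernel cached in one mode could be served to a run expecting the other ABI. A toy illustration of the invariant, using a hypothetical helper rather than the repo's actual `get_hash`:

```python
import hashlib

def cache_key(option_fields: tuple) -> str:
    # Hash every option value that affects the generated artifact.
    return hashlib.sha256(repr(option_fields).encode()).hexdigest()

# Options differing only in iree_launch_async must not collide: they
# produce different entry points and therefore different vmfbs.
assert cache_key(("gfx942", True)) != cache_key(("gfx942", False))
```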
