Skip to content

Commit 6ef2c5e

Browse files
authored
Update workgroup count op syntax (#21656)
The current syntax for workgroup count ops has a trailing operand list, which is problematic when the following op has results, e.g.: ```mlir %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice %foo = arith.constant 1 : index ``` gets parsed as: ```mlir %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice %foo = arith.constant 1 : index ``` and errors out. So far this hasn't been an issue, as the workgroup count ops are typically followed by a `return`, which has no results, but with the introduction of `workgroup_count_split_reduction_modifier` we've begun hitting this case more often. This updates the syntax to have additional literals which remove the ambiguity. Syntax change: Before: ```mlir iree_tensor_ext.dispatch.workgroup_count_from_slice iree_tensor_ext.dispatch.workgroup_count_split_reduction_modifier(%x, %y, %z), iree_tensor_ext.dispatch.workgroup_count_from_dag_root ``` After: ```mlir iree_tensor_ext.dispatch.workgroup_count_from_slice() iree_tensor_ext.dispatch.workgroup_count_split_reduction_modifier workgroups(%x, %y, %z) workload() iree_tensor_ext.dispatch.workgroup_count_from_dag_root() ```
1 parent 718ec6b commit 6ef2c5e

File tree

70 files changed

+323
-311
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+323
-311
lines changed

compiler/plugins/target/CUDA/test/smoketest.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module attributes {
1111

1212
stream.executable public @add_dispatch_executable {
1313
stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
14-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
14+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
1515
stream.return %x, %y, %z : index, index, index
1616
}
1717
builtin.module {
@@ -36,7 +36,7 @@ stream.executable public @add_dispatch_executable {
3636

3737
stream.executable public @mul_dispatch_executable {
3838
stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
39-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
39+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
4040
stream.return %x, %y, %z : index, index, index
4141
}
4242
builtin.module {

compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ module attributes {
1313

1414
stream.executable public @add_dispatch_0 {
1515
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
16-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
16+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
1717
stream.return %x, %y, %z : index, index, index
1818
}
1919
builtin.module {

compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ module attributes {
1515

1616
stream.executable public @add_dispatch_0 {
1717
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
18-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
18+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
1919
stream.return %x, %y, %z : index, index, index
2020
}
2121
builtin.module {

compiler/plugins/target/MetalSPIRV/test/smoketest.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ module attributes {
1515

1616
stream.executable public @reduce_dispatch {
1717
stream.executable.export @reduce_dispatch workgroups(%arg0 : index) -> (index, index, index) {
18-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
18+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
1919
stream.return %x, %y, %z : index, index, index
2020
}
2121
builtin.module {

compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ hal.executable public @main {
3131
iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">
3232
}>) {
3333
hal.executable.export public @matmul_transpose_b ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
34-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
34+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
3535
hal.return %x, %y, %z : index, index, index
3636
}
3737
builtin.module {
@@ -82,7 +82,7 @@ hal.executable public @main {
8282
iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">
8383
}>) {
8484
hal.executable.export public @attention ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
85-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
85+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
8686
hal.return %x, %y, %z : index, index, index
8787
}
8888
builtin.module {
@@ -138,7 +138,7 @@ hal.executable public @main {
138138
iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">
139139
}>) {
140140
hal.executable.export public @attention ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
141-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
141+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
142142
hal.return %x, %y, %z : index, index, index
143143
}
144144
builtin.module {

compiler/plugins/target/ROCM/test/gpu_encoding_attrs.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
stream.executable public @main {
2323
stream.executable.export @main workgroups(%arg0: index) -> (index, index, index) {
24-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
24+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
2525
stream.return %x, %y, %z : index, index, index
2626
}
2727
builtin.module {

compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
hal.executable public @main {
2929
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
3030
hal.executable.export public @matmul_transpose_b ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
31-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
31+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
3232
hal.return %x, %y, %z : index, index, index
3333
}
3434
builtin.module {

compiler/plugins/target/ROCM/test/smoketest.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module attributes {
1010

1111
stream.executable public @add_dispatch_executable {
1212
stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
13-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
13+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
1414
stream.return %x, %y, %z : index, index, index
1515
}
1616
builtin.module {
@@ -35,7 +35,7 @@ stream.executable public @add_dispatch_executable {
3535

3636
stream.executable public @mul_dispatch_executable {
3737
stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
38-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
38+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
3939
stream.return %x, %y, %z : index, index, index
4040
}
4141
builtin.module {
@@ -80,7 +80,7 @@ module attributes {
8080

8181
stream.executable public @executable {
8282
stream.executable.export @export workgroups(%arg0 : index) -> (index, index, index) {
83-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
83+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
8484
stream.return %x, %y, %z : index, index, index
8585
} loc(#loc)
8686
builtin.module {

compiler/plugins/target/ROCM/test/target_device_features.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383

8484
stream.executable public @reduce_dispatch {
8585
stream.executable.export @reduce_dispatch workgroups(%arg0: index) -> (index, index, index) {
86-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
86+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
8787
stream.return %x, %y, %z : index, index, index
8888
}
8989
builtin.module {

compiler/plugins/target/VMVX/test/smoketest.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module attributes {
1010

1111
stream.executable public @add_dispatch_0 {
1212
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
13-
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root %arg0
13+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg0)
1414
stream.return %x, %y, %z : index, index, index
1515
}
1616
builtin.module {

0 commit comments

Comments
 (0)