
Commit 0184eee

[Codegen][RoCDL] Add patterns for lowering bit-width emulation operations to LLVM (#19551)
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 76a7b89 commit 0184eee
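
For context, the "bit-width emulation operations" in the title are the arith/vector ops that remain after sub-byte types such as i4 are emulated on byte-addressable buffers: packed bytes are split with a mask/shift pair, the halves are recombined with vector.interleave, and the result is bitcast/extended to a wider type. A minimal MLIR sketch of that idiom, distilled from the test added in this commit (the function name and vector shapes here are illustrative, not part of the change):

  // Illustrative only: two i4 values live in each i8 of %packed.
  func.func @unpack_i4_sketch(%packed: vector<4xi8>) -> vector<8xf16> {
    %mask = arith.constant dense<15> : vector<4xi8>
    %shift = arith.constant dense<4> : vector<4xi8>
    // Low and high nibbles of every byte.
    %lo = arith.andi %packed, %mask : vector<4xi8>
    %hi = arith.shrui %packed, %shift : vector<4xi8>
    // Restore element order, then widen for arithmetic.
    %both = vector.interleave %lo, %hi : vector<4xi8> -> vector<8xi8>
    %wide = arith.extui %both : vector<8xi8> to vector<8xi32>
    %f = arith.uitofp %wide : vector<8xi32> to vector<8xf16>
    return %f : vector<8xf16>
  }

The changes below register the lowering patterns needed so that ops like vector.interleave and multi-dimensional vector.bitcast do not reach the LLVM/ROCDL conversion unhandled.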

File tree

2 files changed: +169 −80 lines changed


compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp

Lines changed: 4 additions & 0 deletions
@@ -159,7 +159,10 @@ struct ConvertToROCDLPass final
     populateDropSharedMemoryDeallocOpPatterns(patterns);
     populateScalarizeMathOps(patterns);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+    vector::populateBubbleVectorBitCastOpPatterns(patterns);
     vector::populateVectorBroadcastLoweringPatterns(patterns);
+    vector::populateVectorInterleaveLoweringPatterns(patterns);
+    vector::populateVectorInterleaveToShufflePatterns(patterns);
     vector::populateVectorContractLoweringPatterns(
         patterns,
         vector::VectorTransformsOptions().setVectorTransformsOptions(
@@ -225,6 +228,7 @@ struct ConvertToROCDLPass final
     vector::populateVectorRankReducingFMAPattern(llvmPatterns);
     vector::populateVectorInsertExtractStridedSliceTransforms(llvmPatterns);
     vector::populateVectorStepLoweringPatterns(llvmPatterns);
+    vector::populateVectorBitCastLoweringPatterns(llvmPatterns);
     populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
     vector::populateVectorTransferLoweringPatterns(llvmPatterns,
                                                    /*maxTransferRank=*/1);
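
As a rough sketch of what the first two new registrations do (based on the upstream MLIR helpers, not output captured from this pass): populateVectorInterleaveLoweringPatterns unrolls n-D interleaves down to 1-D, and populateVectorInterleaveToShufflePatterns rewrites the 1-D form as a shuffle that the existing vector-to-LLVM patterns already handle. Approximately:

  // Before: a 1-D interleave, e.g. as produced by sub-byte emulation.
  %v = vector.interleave %lo, %hi : vector<4xi8> -> vector<8xi8>
  // After: an equivalent shuffle with alternating indices.
  %v = vector.shuffle %lo, %hi [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi8>, vector<4xi8>

populateBubbleVectorBitCastOpPatterns, as I understand the upstream helper, bubbles vector.bitcast ops through extract/insert-style ops so they end up in a form that the bitcast lowering registered in the second hunk can break down to rank-1 casts.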
Lines changed: 165 additions & 80 deletions
@@ -1,5 +1,5 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-convert-to-rocdl))))" %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-convert-to-rocdl))))" --iree-hip-index-bits=32 %s | FileCheck %s --check-prefix=INDEX32
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --iree-convert-to-rocdl %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --iree-convert-to-rocdl --iree-hip-index-bits=32 %s | FileCheck %s --check-prefix=INDEX32
 
 // Test that that standard and GPU ops are converted to LLVM and NVVM.
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -8,27 +8,22 @@
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable @abs_ex_dispatch_0 {
-  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export @abs_ex_dispatch_0 layout(#pipeline_layout)
-    builtin.module {
-      func.func @abs_ex_dispatch_0() {
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) flags(ReadOnly) : memref<16xf32>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<16xf32>
-        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<16xf32>
-        %3 = gpu.block_id x
-        %4 = gpu.block_dim x
-        %5 = gpu.thread_id x
-        %6 = arith.muli %3, %4 : index
-        %7 = arith.addi %6, %5 : index
-        %9 = memref.load %1[%7] : memref<16xf32>
-        %10 = memref.load %2[%7] : memref<16xf32>
-        %11 = arith.addf %9, %10 : f32
-        memref.store %11, %0[%7] : memref<16xf32>
-        return
-      }
-    }
+builtin.module {
+  func.func @abs_ex_dispatch_0() {
+    %c0 = arith.constant 0 : index
+    %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) flags(ReadOnly) : memref<16xf32>
+    %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<16xf32>
+    %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<16xf32>
+    %3 = gpu.block_id x
+    %4 = gpu.block_dim x
+    %5 = gpu.thread_id x
+    %6 = arith.muli %3, %4 : index
+    %7 = arith.addi %6, %5 : index
+    %9 = memref.load %1[%7] : memref<16xf32>
+    %10 = memref.load %2[%7] : memref<16xf32>
+    %11 = arith.addf %9, %10 : f32
+    memref.store %11, %0[%7] : memref<16xf32>
+    return
   }
 }
 // CHECK-LABEL: llvm.func @abs_ex_dispatch_0
@@ -49,23 +44,18 @@ hal.executable @abs_ex_dispatch_0 {
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable @abs_ex_dispatch_0 {
-  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export @abs_ex_dispatch_0 layout(#pipeline_layout)
-    builtin.module {
-      func.func @reduction_maximum() {
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) :
-            memref<32x64x64xf32, strided<[4096, 64, 1], offset: ?>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<32x64x64xf32,
-            strided<[4096, 64, 1], offset: ?>>
-        %2 = vector.load %0[%c0, %c0, %c0] : memref<32x64x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<2xf32>
-        %3 = vector.reduction <maximumf>, %2 : vector<2xf32> into f32
-        %4 = vector.splat %3 : vector<2xf32>
-        vector.store %4, %1[%c0, %c0, %c0] : memref<32x64x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<2xf32>
-        return
-      }
-    }
+builtin.module {
+  func.func @reduction_maximum() {
+    %c0 = arith.constant 0 : index
+    %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) :
+        memref<32x64x64xf32, strided<[4096, 64, 1], offset: ?>>
+    %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<32x64x64xf32,
+        strided<[4096, 64, 1], offset: ?>>
+    %2 = vector.load %0[%c0, %c0, %c0] : memref<32x64x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<2xf32>
+    %3 = vector.reduction <maximumf>, %2 : vector<2xf32> into f32
+    %4 = vector.splat %3 : vector<2xf32>
+    vector.store %4, %1[%c0, %c0, %c0] : memref<32x64x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<2xf32>
+    return
   }
 }
 // CHECK-LABEL: llvm.func @reduction_maximum
@@ -76,15 +66,10 @@ hal.executable @abs_ex_dispatch_0 {
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable @simple_barrier {
-  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export @simple_barrier layout(#pipeline_layout)
-    builtin.module {
-      func.func @simple_barrier() {
-        gpu.barrier
-        return
-      }
-    }
+builtin.module {
+  func.func @simple_barrier() {
+    gpu.barrier
+    return
   }
 }
 // CHECK-LABEL: llvm.func @simple_barrier
@@ -95,22 +80,17 @@ hal.executable @simple_barrier {
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable @masked_load_store {
-  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export @masked_load_store layout(#pipeline_layout)
-    builtin.module {
-      func.func @masked_load_store() {
-        %c0 = arith.constant 0 : index
-        %idx = gpu.thread_id x
-        %pass_thru = arith.constant dense<0.000000e+00> : vector<1xf32>
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<64xf32, #gpu.address_space<global>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<64xf32, #gpu.address_space<global>>
-        %mask = vector.create_mask %idx : vector<1xi1>
-        %ld = vector.maskedload %0[%idx], %mask, %pass_thru : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-        vector.maskedstore %1[%idx], %mask, %ld : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32>
-        return
-      }
-    }
+builtin.module {
+  func.func @masked_load_store() {
+    %c0 = arith.constant 0 : index
+    %idx = gpu.thread_id x
+    %pass_thru = arith.constant dense<0.000000e+00> : vector<1xf32>
+    %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<64xf32, #gpu.address_space<global>>
+    %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<64xf32, #gpu.address_space<global>>
+    %mask = vector.create_mask %idx : vector<1xi1>
+    %ld = vector.maskedload %0[%idx], %mask, %pass_thru : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+    vector.maskedstore %1[%idx], %mask, %ld : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32>
+    return
   }
 }
 // CHECK-LABEL: llvm.func @masked_load_store
@@ -125,23 +105,128 @@ hal.executable @masked_load_store {
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-hal.executable private @interface_wg_size {
-  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
-      workgroup_size = [32: index, 1: index, 1: index]
-    }
-    builtin.module attributes {} {
-      func.func @interface_wg_size() {
-        %c0 = arith.constant 0.0 : f32
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %subspan = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<64x64xf32>
-        memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32>
-        return
-      }
-    }
+builtin.module attributes {} {
+  func.func @interface_wg_size() {
+    %c0 = arith.constant 0.0 : f32
+    %workgroup_size_x = hal.interface.workgroup.size[0] : index
+    %workgroup_size_y = hal.interface.workgroup.size[1] : index
+    %subspan = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<64x64xf32>
+    memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32>
+    return
   }
 }
 // CHECK-LABEL: llvm.func @interface_wg_size
 // CHECK: %[[WGDIMX:.+]] = rocdl.workgroup.dim.x
 // CHECK: %[[WGDIMY:.+]] = rocdl.workgroup.dim.y
+
+// -----
+
+// Check that the operations generated by emulate bit widths are lowered correctly
+
+module {
+  func.func @emulation_lowering() {
+    %cst = arith.constant dense<4> : vector<2x4xi8>
+    %cst_0 = arith.constant dense<15> : vector<2x4xi8>
+    %c1 = arith.constant 1 : index
+    %cst_1 = arith.constant dense<0> : vector<2x8xi4>
+    %cst_2 = arith.constant dense<0.000000e+00> : vector<8x1x2xf16>
+    %c32 = arith.constant 32 : index
+    %c2 = arith.constant 2 : index
+    %c8 = arith.constant 8 : index
+    %c4 = arith.constant 4 : index
+    %c0 = arith.constant 0 : index
+    %thread_id_x = gpu.thread_id x upper_bound 64
+    %0 = hal.interface.binding.subspan layout(<constants = 1, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<131072x192xf16, #gpu.address_space<global>>
+    memref.assume_alignment %0, 64 : memref<131072x192xf16, #gpu.address_space<global>>
+    %1 = hal.interface.binding.subspan layout(<constants = 1, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<131072x192xf16, #gpu.address_space<global>>
+    memref.assume_alignment %1, 64 : memref<131072x192xf16, #gpu.address_space<global>>
+    %2 = hal.interface.binding.subspan layout(<constants = 1, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : memref<402653184xi8, #gpu.address_space<global>>
+    memref.assume_alignment %2, 64 : memref<402653184xi8, #gpu.address_space<global>>
+    %3 = hal.interface.binding.subspan layout(<constants = 1, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(4) alignment(64) offset(%c0) flags(Indirect) : memref<131072x192x32xf16, #gpu.address_space<global>>
+    memref.assume_alignment %3, 64 : memref<131072x192x32xf16, #gpu.address_space<global>>
+    %4 = arith.divui %thread_id_x, %c4 : index
+    %5 = arith.remui %thread_id_x, %c4 : index
+    %6 = arith.muli %5, %c8 : index
+    %workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 6 : index
+    %workgroup_id_y = hal.interface.workgroup.id[1] upper_bound 131072 : index
+    %7 = arith.muli %4, %c2 : index
+    %8 = arith.muli %workgroup_id_x, %c32 : index
+    %9 = arith.addi %7, %8 : index
+    %10 = vector.load %0[%workgroup_id_y, %9] : memref<131072x192xf16, #gpu.address_space<global>>, vector<2xf16>
+    %11 = vector.broadcast %10 : vector<2xf16> to vector<1x2xf16>
+    %12 = vector.insert %11, %cst_2 [0] : vector<1x2xf16> into vector<8x1x2xf16>
+    %13 = vector.insert %11, %12 [1] : vector<1x2xf16> into vector<8x1x2xf16>
+    %14 = vector.insert %11, %13 [2] : vector<1x2xf16> into vector<8x1x2xf16>
+    %15 = vector.insert %11, %14 [3] : vector<1x2xf16> into vector<8x1x2xf16>
+    %16 = vector.insert %11, %15 [4] : vector<1x2xf16> into vector<8x1x2xf16>
+    %17 = vector.insert %11, %16 [5] : vector<1x2xf16> into vector<8x1x2xf16>
+    %18 = vector.insert %11, %17 [6] : vector<1x2xf16> into vector<8x1x2xf16>
+    %19 = vector.insert %11, %18 [7] : vector<1x2xf16> into vector<8x1x2xf16>
+    %20 = vector.transpose %19, [1, 2, 0] : vector<8x1x2xf16> to vector<1x2x8xf16>
+    %21 = vector.load %1[%workgroup_id_y, %9] : memref<131072x192xf16, #gpu.address_space<global>>, vector<2xf16>
+    %22 = vector.broadcast %21 : vector<2xf16> to vector<1x2xf16>
+    %23 = vector.insert %22, %cst_2 [0] : vector<1x2xf16> into vector<8x1x2xf16>
+    %24 = vector.insert %22, %23 [1] : vector<1x2xf16> into vector<8x1x2xf16>
+    %25 = vector.insert %22, %24 [2] : vector<1x2xf16> into vector<8x1x2xf16>
+    %26 = vector.insert %22, %25 [3] : vector<1x2xf16> into vector<8x1x2xf16>
+    %27 = vector.insert %22, %26 [4] : vector<1x2xf16> into vector<8x1x2xf16>
+    %28 = vector.insert %22, %27 [5] : vector<1x2xf16> into vector<8x1x2xf16>
+    %29 = vector.insert %22, %28 [6] : vector<1x2xf16> into vector<8x1x2xf16>
+    %30 = vector.insert %22, %29 [7] : vector<1x2xf16> into vector<8x1x2xf16>
+    %31 = vector.transpose %30, [1, 2, 0] : vector<8x1x2xf16> to vector<1x2x8xf16>
+    %c3072 = arith.constant 3072 : index
+    %32 = arith.muli %workgroup_id_y, %c3072 : index
+    %c16 = arith.constant 16 : index
+    %33 = arith.muli %9, %c16 : index
+    %34 = arith.addi %32, %33 : index
+    %c2_3 = arith.constant 2 : index
+    %c0_4 = arith.constant 0 : index
+    %c-1 = arith.constant -1 : index
+    %35 = arith.cmpi slt, %6, %c0_4 : index
+    %36 = arith.subi %c-1, %6 : index
+    %37 = arith.select %35, %36, %6 : index
+    %38 = arith.divsi %37, %c2_3 : index
+    %39 = arith.subi %c-1, %38 : index
+    %40 = arith.select %35, %39, %38 : index
+    %41 = arith.addi %34, %40 : index
+    %42 = vector.load %2[%41] : memref<402653184xi8, #gpu.address_space<global>>, vector<4xi8>
+    %43 = vector.bitcast %42 : vector<4xi8> to vector<8xi4>
+    %44 = vector.insert %43, %cst_1 [0] : vector<8xi4> into vector<2x8xi4>
+    %45 = arith.addi %9, %c1 : index
+    %c3072_5 = arith.constant 3072 : index
+    %46 = arith.muli %workgroup_id_y, %c3072_5 : index
+    %c16_6 = arith.constant 16 : index
+    %47 = arith.muli %45, %c16_6 : index
+    %48 = arith.addi %46, %47 : index
+    %c2_7 = arith.constant 2 : index
+    %c0_8 = arith.constant 0 : index
+    %c-1_9 = arith.constant -1 : index
+    %49 = arith.cmpi slt, %6, %c0_8 : index
+    %50 = arith.subi %c-1_9, %6 : index
+    %51 = arith.select %49, %50, %6 : index
+    %52 = arith.divsi %51, %c2_7 : index
+    %53 = arith.subi %c-1_9, %52 : index
+    %54 = arith.select %49, %53, %52 : index
+    %55 = arith.addi %48, %54 : index
+    %56 = vector.load %2[%55] : memref<402653184xi8, #gpu.address_space<global>>, vector<4xi8>
+    %57 = vector.bitcast %56 : vector<4xi8> to vector<8xi4>
+    %58 = vector.insert %57, %44 [1] : vector<8xi4> into vector<2x8xi4>
+    %59 = vector.bitcast %58 : vector<2x8xi4> to vector<2x4xi8>
+    %60 = arith.andi %59, %cst_0 : vector<2x4xi8>
+    %61 = arith.shrui %59, %cst : vector<2x4xi8>
+    %62 = vector.interleave %60, %61 : vector<2x4xi8> -> vector<2x8xi8>
+    %63 = arith.extui %62 : vector<2x8xi8> to vector<2x8xi32>
+    %64 = arith.uitofp %63 : vector<2x8xi32> to vector<2x8xf16>
+    %65 = vector.extract %20[0] : vector<2x8xf16> from vector<1x2x8xf16>
+    %66 = arith.mulf %64, %65 : vector<2x8xf16>
+    %67 = vector.extract %31[0] : vector<2x8xf16> from vector<1x2x8xf16>
+    %68 = arith.addf %66, %67 : vector<2x8xf16>
+    %69 = vector.extract %68[0] : vector<8xf16> from vector<2x8xf16>
+    vector.store %69, %3[%workgroup_id_y, %9, %6] : memref<131072x192x32xf16, #gpu.address_space<global>>, vector<8xf16>
+    %70 = vector.extract %68[1] : vector<8xf16> from vector<2x8xf16>
+    vector.store %70, %3[%workgroup_id_y, %45, %6] : memref<131072x192x32xf16, #gpu.address_space<global>>, vector<8xf16>
+    return
+  }
+}
+// CHECK-LABEL: llvm.func @emulation_lowering(
+// CHECK-NOT: builtin.unrealized_conversion_cast
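
The trailing CHECK-NOT is the crux of the new test: if any of the emulation ops above (the rank-2 vector.bitcast, the vector.interleave, and so on) survived to the LLVM conversion, the dialect conversion would typically leave builtin.unrealized_conversion_cast ops bridging the unconverted values, so their absence indicates the newly registered lowerings fired. For the rank-2 bitcast, the registered lowering roughly unrolls it into row-wise rank-1 bitcasts, along these lines (an illustrative sketch, not compiler output; %zero stands for an all-zero vector<2x4xi8> constant):

  // Before:
  %b = vector.bitcast %x : vector<2x8xi4> to vector<2x4xi8>
  // After, approximately:
  %row0  = vector.extract %x[0] : vector<8xi4> from vector<2x8xi4>
  %cast0 = vector.bitcast %row0 : vector<8xi4> to vector<4xi8>
  %acc0  = vector.insert %cast0, %zero [0] : vector<4xi8> into vector<2x4xi8>
  %row1  = vector.extract %x[1] : vector<8xi4> from vector<2x8xi4>
  %cast1 = vector.bitcast %row1 : vector<8xi4> to vector<4xi8>
  %b     = vector.insert %cast1, %acc0 [1] : vector<4xi8> into vector<2x4xi8>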
