Commit eb58f82

[LLVMGPU] Add fixes and tests for horizontally fused gemms through GPU pipeline. (#19930)
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 04dc4a4 commit eb58f82

File tree

5 files changed, +249 -3 lines changed
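For context: a "horizontally fused" GEMM is a single linalg.generic that computes several independent contractions over a shared operand and yields one result per contraction. A minimal sketch of the pattern, with hypothetical names and shapes, simplified (fewer operands, smaller sizes) from the @fused_contraction_2 test added below:

#lhs = affine_map<(d0, d1, d2) -> (d0, d2)>
#rhs = affine_map<(d0, d1, d2) -> (d2, d1)>
#out = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @horizontally_fused(%a: tensor<128x64xf32>,
    %b0: tensor<64x128xf32>, %b1: tensor<64x128xf32>)
    -> (tensor<128x128xf32>, tensor<128x128xf32>) {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<128x128xf32>
  %fill = linalg.fill ins(%cst : f32)
      outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
  // One loop nest computes two matmuls that share the LHS %a; each result
  // carries its own accumulator through the region.
  %r:2 = linalg.generic {
      indexing_maps = [#lhs, #rhs, #rhs, #out, #out],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b0, %b1 : tensor<128x64xf32>, tensor<64x128xf32>, tensor<64x128xf32>)
      outs(%fill, %fill : tensor<128x128xf32>, tensor<128x128xf32>) {
  ^bb0(%in: f32, %in_0: f32, %in_1: f32, %acc0: f32, %acc1: f32):
    %m0 = arith.mulf %in, %in_0 : f32
    %s0 = arith.addf %acc0, %m0 : f32
    %m1 = arith.mulf %in, %in_1 : f32
    %s1 = arith.addf %acc1, %m1 : f32
    linalg.yield %s0, %s1 : f32, f32
  } -> (tensor<128x128xf32>, tensor<128x128xf32>)
  return %r#0, %r#1 : tensor<128x128xf32>, tensor<128x128xf32>
}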

compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp

Lines changed: 7 additions & 3 deletions
@@ -203,14 +203,18 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
   IRRewriter builder(opOperand.getOwner());
   // Handle case where constantOp may have multiple consumers with different
   // layouts by creating a copy of constOp for other users.
-  if (!opOperand.get().hasOneUse() && !vectorLayout &&
+  if (!opOperand.get().hasOneUse() &&
       llvm::isa_and_nonnull<arith::ConstantOp, vector::StepOp>(
           opOperand.get().getDefiningOp())) {
     builder.setInsertionPoint(opOperand.get().getDefiningOp());
     Operation *copiedConstOp = builder.clone(*opOperand.get().getDefiningOp());
     Value copiedConst = copiedConstOp->getResult(0);
-    builder.replaceAllUsesExcept(opOperand.get(), copiedConst,
-                                 opOperand.getOwner());
+    DistributionLayout *newConstLayout =
+        propagation->getLatticeElement(copiedConst);
+    newConstLayout->subscribeEnforcement(enforcement);
+    (void)newConstLayout->resolve(rhs);
+    opOperand.set(copiedConst);
+    return ChangeResult::NoChange;
   }

   ResolutionResult result = doResolution(rhs);
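This fix targets the case where one constant (or vector.step) result reaches several consumers that resolve to conflicting layouts: rather than redirecting all other uses as before, the analysis now clones the defining op for the conflicting operand, creates a lattice element for the clone, subscribes it to enforcement, and resolves it to the requested layout. A hypothetical IR shape that exercises this path (illustrative, not taken from this commit): a single zero accumulator feeding two contractions that downstream anchors may place in different distributed layouts.

#mk = affine_map<(d0, d1, d2) -> (d0, d2)>
#kn = affine_map<(d0, d1, d2) -> (d2, d1)>
#mn = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @shared_accumulator(%lhs0: vector<16x16xf16>, %rhs0: vector<16x16xf16>,
    %lhs1: vector<16x16xf16>, %rhs1: vector<16x16xf16>)
    -> (vector<16x16xf16>, vector<16x16xf16>) {
  // One constant with two uses; if the two contractions are anchored to
  // different layouts, the analysis clones %acc so each copy can carry its
  // own layout.
  %acc = arith.constant dense<0.0> : vector<16x16xf16>
  %r0 = vector.contract {indexing_maps = [#mk, #kn, #mn],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind<add>}
      %lhs0, %rhs0, %acc
      : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
  %r1 = vector.contract {indexing_maps = [#mk, #kn, #mn],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind<add>}
      %lhs1, %rhs1, %acc
      : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
  return %r0, %r1 : vector<16x16xf16>, vector<16x16xf16>
}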

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 1 addition & 0 deletions
@@ -838,6 +838,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,

   // Set anchors at tensor level for vector distribution later and hoist out
   // loop invariant anchors.
+  funcPassManager.addPass(createDecomposeHorizontallyFusedGemmsPass());
   funcPassManager.addPass(createLLVMGPUConfigureTensorLayoutsPass());
   funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
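The new pass runs immediately before LLVMGPUConfigureTensorLayouts. This commit does not show its output, but going by the pass name, a plausible reading is that it splits a multi-result horizontally fused generic into single-result contractions so the layout pass can anchor each one independently. A hypothetical sketch (names and shapes illustrative) of the decomposed form of the two-result fused GEMM sketched near the top of this page:

#lhs = affine_map<(d0, d1, d2) -> (d0, d2)>
#rhs = affine_map<(d0, d1, d2) -> (d2, d1)>
#out = affine_map<(d0, d1, d2) -> (d0, d1)>
// Each contraction now has a single result; the shared LHS %a is simply
// used twice, leaving layout assignment free to differ per contraction.
func.func @decomposed(%a: tensor<128x64xf32>,
    %b0: tensor<64x128xf32>, %b1: tensor<64x128xf32>)
    -> (tensor<128x128xf32>, tensor<128x128xf32>) {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<128x128xf32>
  %fill = linalg.fill ins(%cst : f32)
      outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
  %r0 = linalg.generic {
      indexing_maps = [#lhs, #rhs, #out],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b0 : tensor<128x64xf32>, tensor<64x128xf32>)
      outs(%fill : tensor<128x128xf32>) {
  ^bb0(%in: f32, %in_0: f32, %acc: f32):
    %m = arith.mulf %in, %in_0 : f32
    %s = arith.addf %acc, %m : f32
    linalg.yield %s : f32
  } -> tensor<128x128xf32>
  %r1 = linalg.generic {
      indexing_maps = [#lhs, #rhs, #out],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b1 : tensor<128x64xf32>, tensor<64x128xf32>)
      outs(%fill : tensor<128x128xf32>) {
  ^bb0(%in: f32, %in_0: f32, %acc: f32):
    %m = arith.mulf %in, %in_0 : f32
    %s = arith.addf %acc, %m : f32
    linalg.yield %s : f32
  } -> tensor<128x128xf32>
  return %r0, %r1 : tensor<128x128xf32>, tensor<128x128xf32>
}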

compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ iree_lit_test_suite(
             "extract_address_computation_gpu.mlir",
             "gpu_set_num_workgroups.mlir",
             "gpu_pipeline_generalize_named_ops.mlir",
+            "horizontal_fusion_pipeline.mlir",
             "link_executables.mlir",
             "nvvm_extract_address_computation.mlir",
             "nvvm_pipeline_test.mlir",

compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ iree_lit_test_suite(
     "extract_address_computation_gpu.mlir"
     "gpu_pipeline_generalize_named_ops.mlir"
     "gpu_set_num_workgroups.mlir"
+    "horizontal_fusion_pipeline.mlir"
     "illegal_configuration.mlir"
     "legalize.mlir"
     "linalg_transform.mlir"
compiler/src/iree/compiler/Codegen/LLVMGPU/test/horizontal_fusion_pipeline.mlir

Lines changed: 239 additions & 0 deletions

@@ -0,0 +1,239 @@
// RUN: iree-opt --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target))" %s --split-input-file | FileCheck %s

func.func @fused_contraction_1(%arg0: tensor<2x4096x640xf16>,
    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
    %arg3 : tensor<10x64x640xf16>)
    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>) {
  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0: f32
  %13 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%13, %13, %13
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) {
  ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
    %18 = arith.extf %in : f16 to f32
    %19 = arith.extf %in_0 : f16 to f32
    %20 = arith.mulf %18, %19 : f32
    %21 = arith.addf %out, %20 : f32
    %22 = arith.extf %in_1 : f16 to f32
    %23 = arith.mulf %18, %22 : f32
    %24 = arith.addf %out_3, %23 : f32
    %25 = arith.extf %in_2 : f16 to f32
    %26 = arith.mulf %18, %25 : f32
    %27 = arith.addf %out_4, %26 : f32
    linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>)
  %15 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %18 = arith.truncf %in : f32 to f16
    linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %16 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %18 = arith.truncf %in : f32 to f16
    linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %17 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#2 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %18 = arith.truncf %in : f32 to f16
    linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  return %15, %16, %17
      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>
}
// CHECK-LABEL: func @fused_contraction_1
// CHECK-COUNT-24: amdgpu.mfma

// -----

func.func @fused_contraction_2(%arg0: tensor<4096x640xf32>,
    %arg1 : tensor<640x640xf32>, %arg2 : tensor<640x640xf32>,
    %arg3 : tensor<640x640xf32>)
    -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
  %11 = tensor.empty() : tensor<4096x640xf32>
  %12 = tensor.empty() : tensor<4096x640xf32>
  %cst = arith.constant 0.0: f32
  %13 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<4096x640xf32>) -> tensor<4096x640xf32>
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<4096x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>,
            tensor<640x640xf32>)
      outs(%13, %13, %13
          : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
  ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32, %out_4: f32):
    %20 = arith.mulf %in, %in_0 : f32
    %21 = arith.addf %out, %20 : f32
    %23 = arith.mulf %in, %in_1 : f32
    %24 = arith.addf %out_3, %23 : f32
    %26 = arith.mulf %in, %in_2 : f32
    %27 = arith.addf %out_4, %26 : f32
    linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>)
  return %14#0, %14#1, %14#2
      : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>
}
// CHECK-LABEL: func @fused_contraction_2
// CHECK-COUNT-24: amdgpu.mfma

// -----

func.func @fused_contraction_3(%arg0 : tensor<2x4096x640xi8>,
    %arg1 : tensor<2x640x640xi8>, %arg2 : tensor<2x640x640xi8>)
    -> (tensor<2x4096x640xf16>, tensor<2x4096x640xf16>) {
  %c0_i32 = arith.constant 0 : i32
  %18 = tensor.empty() : tensor<2x4096x640xf16>
  %19 = tensor.empty() : tensor<2x4096x640xi32>
  %20 = linalg.fill ins(%c0_i32 : i32)
      outs(%19 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
  %21:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
      outs(%20, %20 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
  ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
    %24 = arith.extsi %in : i8 to i32
    %25 = arith.extsi %in_0 : i8 to i32
    %26 = arith.muli %24, %25 : i32
    %27 = arith.addi %out, %26 : i32
    %28 = arith.extsi %in_1 : i8 to i32
    %29 = arith.muli %24, %28 : i32
    %30 = arith.addi %out_2, %29 : i32
    linalg.yield %27, %30 : i32, i32
  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
  %22 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%21#0 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
  ^bb0(%in: i32, %out: f16):
    %27 = arith.sitofp %in : i32 to f32
    %29 = arith.truncf %27 : f32 to f16
    linalg.yield %29 : f16
  } -> tensor<2x4096x640xf16>
  %23 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%21#1 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
  ^bb0(%in: i32, %out: f16):
    %27 = arith.sitofp %in : i32 to f32
    %29 = arith.truncf %27 : f32 to f16
    linalg.yield %29 : f16
  } -> tensor<2x4096x640xf16>
  return %22, %23 : tensor<2x4096x640xf16>, tensor<2x4096x640xf16>
}
// CHECK-LABEL: func @fused_contraction_3
// CHECK-COUNT-24: amdgpu.mfma

// -----

func.func @fused_contraction_4(%arg0: tensor<2x4096x640xf16>,
    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
    %arg3 : tensor<10x64x640xf16>)
    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>) {
  %9 = tensor.empty() : tensor<2x10x64x4096xf16>
  %10 = tensor.empty() : tensor<2x10x64x4096xf32>
  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0: f32
  %fill0 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  %fill1 = linalg.fill ins(%cst : f32)
      outs(%10 : tensor<2x10x64x4096xf32>) -> tensor<2x10x64x4096xf32>
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%fill0, %fill0, %fill1
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) {
  ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
    %18 = arith.extf %in : f16 to f32
    %19 = arith.extf %in_0 : f16 to f32
    %20 = arith.mulf %18, %19 : f32
    %21 = arith.addf %out, %20 : f32
    %22 = arith.extf %in_1 : f16 to f32
    %23 = arith.mulf %18, %22 : f32
    %24 = arith.addf %out_3, %23 : f32
    %25 = arith.extf %in_2 : f16 to f32
    %26 = arith.mulf %18, %25 : f32
    %27 = arith.addf %out_4, %26 : f32
    linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>)
  %15 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %18 = arith.truncf %in : f32 to f16
    linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %16 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %18 = arith.truncf %in : f32 to f16
    linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %17 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#2 : tensor<2x10x64x4096xf32>) outs(%9 : tensor<2x10x64x4096xf16>) {
  ^bb0(%in: f32, %out: f16):
    %18 = arith.truncf %in : f32 to f16
    linalg.yield %18 : f16
  } -> tensor<2x10x64x4096xf16>
  return %15, %16, %17
      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>
}
// CHECK-LABEL: func @fused_contraction_4
// CHECK-COUNT-24: amdgpu.mfma
