
Commit f59480f

[GPU] Add rematerialize parallel ops in the vector distribute pipeline (#21073)
This enables elementwise op fusion, and some cases might benefit from it. Fixes issue #20875.
1 parent ca25934 commit f59480f
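
For context, the rematerialization this commit hooks into the vector distribute pipeline fuses a parallel elementwise producer into its consumer by recomputing the producer's body. A minimal sketch of the kind of IR it targets (hypothetical example, not taken from the commit):

#map = affine_map<(d0) -> (d0)>
func.func @remat_example(%a: tensor<8xf32>, %b: tensor<8xf32>) -> tensor<8xf32> {
  %empty = tensor.empty() : tensor<8xf32>
  // Parallel elementwise producer.
  %0 = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel"]}
      ins(%a : tensor<8xf32>) outs(%empty : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %e = math.exp %in : f32
    linalg.yield %e : f32
  } -> tensor<8xf32>
  // Parallel elementwise consumer. Rematerialization recomputes math.exp
  // inside this body, collapsing the two generics into one.
  %1 = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel"]}
      ins(%0, %b : tensor<8xf32>, tensor<8xf32>) outs(%empty : tensor<8xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %m = arith.mulf %in, %in_0 : f32
    linalg.yield %m : f32
  } -> tensor<8xf32>
  return %1 : tensor<8xf32>
}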

File tree

3 files changed: +81 -2 lines changed

compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp

Lines changed: 47 additions & 2 deletions
@@ -25,6 +25,47 @@ static bool isScalarOrTensorOfSizeOne(Type t) {
   return t.isIntOrIndexOrFloat();
 }
 
+/// This function checks whether the `genericOp` has any external captures,
+/// i.e., whether it uses any values that are defined outside of its body.
+/// %10 = linalg.generic {indexing_maps = [#map, #map],
+///                       iterator_types = ["parallel", "parallel"]}
+///     ins(%5 : tensor<4096x64xi64>) outs(%9 : tensor<4096x64xf16>) {
+///   ^bb0(%in: i64, %out: f16):
+///     %14 = linalg.index 0 : index
+///     %15 = arith.index_cast %in : i64 to index
+///     %extracted = tensor.extract %4[%14, %15] : tensor<4096x64xf16>
+///     linalg.yield %extracted : f16
+///   } -> tensor<4096x64xf16>
+/// Here %4 is an external capture used via tensor.extract inside
+/// linalg.generic, hence the above `genericOp` has an external capture.
+static bool hasExternalCapture(linalg::GenericOp genericOp) {
+  Block &body = genericOp.getRegion().front();
+  for (Operation &op : body.getOperations()) {
+    for (Value operand : op.getOperands()) {
+      if (auto bArg = dyn_cast<BlockArgument>(operand)) {
+        // Check whether the operand lies in the same block.
+        if (bArg.getOwner() == &body) {
+          continue;
+        }
+        return true;
+      }
+      Operation *defOp = operand.getDefiningOp();
+      // A scalar constant is allowed.
+      if (defOp && defOp->hasTrait<mlir::OpTrait::ConstantLike>()) {
+        Type type = operand.getType();
+        if (type.isIntOrFloat() || type.isIndex()) {
+          continue;
+        }
+      }
+      // If the defining op is not inside the block, it's an external value.
+      if (!defOp || defOp->getBlock() != &body) {
+        return true;
+      }
+    }
+  }
+  return false; // All operands are locally defined or block arguments.
+}
+
 /// Rematerialize all parallel elementwise operations into its users within a
 /// `flow.dispatch.region`.
 struct RematerializeParallelOpsPattern
@@ -44,9 +85,13 @@ struct RematerializeParallelOpsPattern
 
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
-      if (!linalg::areElementwiseOpsFusable(&opOperand))
+      if (!linalg::areElementwiseOpsFusable(&opOperand)) {
         continue;
-
+      }
+      auto producer = opOperand.get().getDefiningOp<linalg::GenericOp>();
+      if (producer && hasExternalCapture(producer)) {
+        continue;
+      }
       FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
           linalg::fuseElementwiseOps(rewriter, &opOperand);
       if (succeeded(fusionResult)) {
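
Note the ConstantLike carve-out above: a scalar constant captured from outside the body does not count as an external capture, so only genuine external values (such as the tensor.extract source in the doc comment) block fusion. A hypothetical producer that remains fusable under this check (sketch, assuming the helper's logic as shown):

#map = affine_map<(d0) -> (d0)>
func.func @scalar_capture_ok(%a: tensor<8xf32>) -> tensor<8xf32> {
  // %cst is defined outside the region, but it is ConstantLike and scalar,
  // so hasExternalCapture() returns false and this generic stays fusable.
  %cst = arith.constant 2.0 : f32
  %empty = tensor.empty() : tensor<8xf32>
  %0 = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel"]}
      ins(%a : tensor<8xf32>) outs(%empty : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %m = arith.mulf %in, %cst : f32
    linalg.yield %m : f32
  } -> tensor<8xf32>
  return %0 : tensor<8xf32>
}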

compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir

Lines changed: 31 additions & 0 deletions
@@ -138,3 +138,34 @@ func.func @no_rematerialize_scalar_ops(%arg0 : tensor<f32>) -> tensor<f32> {
 // CHECK: linalg.generic
 // CHECK: linalg.generic
 // CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Do not fuse a generic that has an external capture.
+func.func @no_external_capture_fusion(%arg0: tensor<4096x64xi64>, %arg1: tensor<4096x64xf16>, %arg2: tensor<4096x64xf16>, %arg3: f32, %arg4: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
+  %empty = tensor.empty() : tensor<4096x64xf16>
+  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4096x64xi64>) outs(%arg1 : tensor<4096x64xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %3 = linalg.index 0 : index
+    %4 = arith.index_cast %in : i64 to index
+    %extracted = tensor.extract %empty[%3, %4] : tensor<4096x64xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4096x64xf16>
+  %1 = linalg.fill ins(%arg3 : f32) outs(%arg4 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %0 : tensor<4096x64xf16>, tensor<4096x64xf16>) outs(%1 : tensor<4096x4096xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %3 = arith.extf %in : f16 to f32
+    %4 = arith.extf %in_0 : f16 to f32
+    %5 = arith.mulf %3, %4 : f32
+    %6 = arith.addf %out, %5 : f32
+    linalg.yield %6 : f32
+  } -> tensor<4096x4096xf32>
+  return %2 : tensor<4096x4096xf32>
+}
+// CHECK-LABEL: func @no_external_capture_fusion(
+// CHECK: linalg.generic
+// CHECK: linalg.generic
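
The diff omits the test file's lit header; tests for this pass are typically driven by a RUN line of roughly the following shape. The exact pass flag is an assumption here, the authoritative one is in the actual test file:

// RUN: iree-opt --split-input-file --iree-codegen-rematerialize-parallel-ops %s | FileCheck %s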

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 3 additions & 0 deletions
@@ -865,6 +865,9 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
       /*convertToDpsOptions=*/std::nullopt,
       /*reorderStrategy=*/reorderStrategy);
 
+  // Some of the elementwise fusion can benefit from this pass.
+  funcPassManager.addPass(createRematerializeParallelOpsPass());
+
   if (usePadToModelSharedMemcpy) {
     funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
   }
