Skip to content

Commit 9323cd2

Browse files
authored
[Codegen][VectorDistribute] Add pattern to distribute poison (#21573)
Adds a pattern that distributes `ub.poison` ops; it is essentially the same as the existing pattern that distributes constants. Also removes some unused headers from the touched file. --------- Signed-off-by: James Newling <[email protected]>
1 parent d2c9cd4 commit 9323cd2

File tree

6 files changed

+62
-30
lines changed

6 files changed

+62
-30
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ iree_compiler_cc_library(
8585
"@llvm-project//mlir:DialectUtils",
8686
"@llvm-project//mlir:IR",
8787
"@llvm-project//mlir:SCFDialect",
88+
"@llvm-project//mlir:UBDialect",
8889
"@llvm-project//mlir:VectorDialect",
8990
],
9091
)

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ iree_cc_library(
6060
MLIRAnalysis
6161
MLIRIR
6262
MLIRSCFDialect
63+
MLIRUBDialect
6364
MLIRVectorDialect
6465
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
6566
PUBLIC

compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ iree_compiler_cc_library(
4848
"@llvm-project//mlir:Pass",
4949
"@llvm-project//mlir:TransformUtils",
5050
"@llvm-project//mlir:Transforms",
51+
"@llvm-project//mlir:UBDialect",
5152
],
5253
)
5354

@@ -167,6 +168,7 @@ iree_compiler_cc_library(
167168
"@llvm-project//mlir:TilingInterface",
168169
"@llvm-project//mlir:TransformUtils",
169170
"@llvm-project//mlir:Transforms",
171+
"@llvm-project//mlir:UBDialect",
170172
"@llvm-project//mlir:ValueBoundsOpInterface",
171173
"@llvm-project//mlir:VectorDialect",
172174
"@llvm-project//mlir:VectorToSCF",

compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ iree_cc_library(
3333
MLIRPass
3434
MLIRTransformUtils
3535
MLIRTransforms
36+
MLIRUBDialect
3637
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
3738
iree::compiler::Dialect::HAL::IR
3839
iree::compiler::Utils
@@ -135,6 +136,7 @@ iree_cc_library(
135136
MLIRTilingInterface
136137
MLIRTransformUtils
137138
MLIRTransforms
139+
MLIRUBDialect
138140
MLIRValueBoundsOpInterface
139141
MLIRVectorDialect
140142
MLIRVectorToSCF

compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,16 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
#include <numeric>
87
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
98
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
10-
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
11-
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
12-
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
13-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
149
#include "mlir/Dialect/Affine/Utils.h"
1510
#include "mlir/Dialect/Arith/IR/Arith.h"
1611
#include "mlir/Dialect/Func/IR/FuncOps.h"
17-
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1812
#include "mlir/Dialect/SCF/IR/SCF.h"
13+
#include "mlir/Dialect/UB/IR/UBOps.h"
14+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1915
#include "mlir/IR/Attributes.h"
16+
#include "mlir/IR/OpDefinition.h"
2017
#include "mlir/IR/Verifier.h"
2118
#include "mlir/Rewrite/PatternApplicator.h"
2219

@@ -57,6 +54,30 @@ struct DistributeConstants final : OpDistributionPattern<arith::ConstantOp> {
5754
}
5855
};
5956

57+
/// Distribution pattern for `ub.poison` ops whose result is a vector.
///
/// The op is replaced by a fresh `ub.poison` whose vector type is built from
/// the shape returned by `signature[poisonVal].getDistributedShape()` and the
/// original element type. Per the commit description this is essentially the
/// same as the pattern that distributes constants.
struct DistributePoison final : OpDistributionPattern<ub::PoisonOp> {
58+
// Inherit the constructors of the base distribution pattern.
using OpDistributionPattern::OpDistributionPattern;
59+
60+
/// Rewrites `poisonOp` into its distributed form.
///
/// Returns failure() when the op's result is not a `VectorValue`; otherwise
/// creates the distributed poison, replaces the op, and returns success().
LogicalResult matchAndRewrite(ub::PoisonOp poisonOp,
61+
DistributionSignature &signature,
62+
PatternRewriter &rewriter) const override {
63+
64+
// Only vector-typed poison values are handled; the dyn_cast yields null
// for anything else.
auto poisonVal = dyn_cast<VectorValue>(poisonOp.getResult());
65+
if (!poisonVal)
66+
return failure();
67+
68+
// Shape of the distributed vector, as reported by the signature entry for
// this value (presumably its assigned layout -- confirm against
// DistributionSignature).
SmallVector<int64_t> distributedShape =
69+
signature[poisonVal].getDistributedShape();
70+
71+
// The element type is carried over unchanged; only the shape differs.
Type elementType = poisonVal.getType().getElementType();
72+
auto vectorType = VectorType::get(distributedShape, elementType);
73+
// Create the replacement poison at the location of the original value.
auto distributedOp =
74+
ub::PoisonOp::create(rewriter, poisonVal.getLoc(), vectorType);
75+
// Hand the single result of the new op to the distribution helper, which
// is expected to replace `poisonOp` with it.
replaceOpWithDistributedValues(rewriter, poisonOp,
76+
distributedOp->getResult(0));
77+
return success();
78+
}
79+
};
80+
6081
struct DistributeElementwise final
6182
: OpTraitDistributionPattern<OpTrait::Elementwise> {
6283
using OpTraitDistributionPattern::OpTraitDistributionPattern;
@@ -336,8 +357,8 @@ struct DistributeTrivialExtract final
336357
} // namespace
337358

338359
void populateGPUDistributionPatterns(RewritePatternSet &patterns) {
339-
patterns.add<DistributeConstants, DistributeScfFor, DistributeTrivialExtract>(
340-
patterns.getContext());
360+
patterns.add<DistributeConstants, DistributePoison, DistributeScfFor,
361+
DistributeTrivialExtract>(patterns.getContext());
341362
// Elementwise patterns.
342363
patterns.add<DistributeElementwise>(patterns.getContext());
343364
patterns.add<DistributeTrivialLayoutConversions>(patterns.getContext());

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize --cse %s | FileCheck %s
1+
// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file %s | FileCheck %s
22

33
#nested = #iree_vector_ext.nested_layout<
4-
subgroup_tile = [2, 1, 1],
5-
batch_tile = [8, 2, 4],
6-
outer_tile = [1, 4, 4],
7-
thread_tile = [8, 2, 4],
4+
subgroup_tile = [2, 1, 1],
5+
batch_tile = [8, 2, 4],
6+
outer_tile = [1, 4, 4],
7+
thread_tile = [8, 2, 4],
88
element_tile = [1, 8, 2],
9-
10-
subgroup_strides = [1, 1, 1],
11-
thread_strides = [1, 8, 16]
9+
subgroup_strides = [1, 1, 1],
10+
thread_strides = [1, 8, 16]
1211
>
1312

1413
// CHECK-LABEL: @distribute_elementwise_nested_layout_f16
@@ -28,15 +27,22 @@ func.func @distribute_elementwise_nested_layout_f16(%a: vector<128x128x128xf16>,
2827
return %d : vector<128x128x128xf16>
2928
}
3029

31-
#layout = #iree_vector_ext.nested_layout<
32-
subgroup_tile = [1, 1],
33-
batch_tile = [1, 1],
34-
outer_tile = [1, 1],
35-
thread_tile = [1, 1],
36-
element_tile = [16, 16],
30+
// Test for distributing ub.poison: the 128x128x128xf16 poison value below is
// given the #nested layout through to_layout, and the expected output is a
// ub.poison of type vector<8x2x4x1x4x4x1x8x2xf16>.
// NOTE(review): that shape looks like #nested's batch_tile, outer_tile and
// element_tile concatenated -- confirm against the layout definition.
// CHECK-LABEL: @distribute_poison
31+
func.func @distribute_poison() -> vector<128x128x128xf16> {
32+
// CHECK: ub.poison : vector<8x2x4x1x4x4x1x8x2xf16>
33+
%root = ub.poison : vector<128x128x128xf16>
34+
%rootl = iree_vector_ext.to_layout %root to layout(#nested) : vector<128x128x128xf16>
35+
return %rootl: vector<128x128x128xf16>
36+
}
3737

38+
#layout = #iree_vector_ext.nested_layout<
39+
subgroup_tile = [1, 1],
40+
batch_tile = [1, 1],
41+
outer_tile = [1, 1],
42+
thread_tile = [1, 1],
43+
element_tile = [16, 16],
3844
subgroup_strides = [1, 1],
39-
thread_strides = [1, 1]
45+
thread_strides = [1, 1]
4046
>
4147

4248
// CHECK-LABEL: @distribute_scf_for
@@ -63,14 +69,13 @@ func.func @distribute_scf_for(%a: vector<16x16xi32>, %b: vector<16x16xi32>) -> v
6369
}
6470

6571
#layout_0d = #iree_vector_ext.nested_layout<
66-
subgroup_tile = [],
67-
batch_tile = [],
68-
outer_tile = [],
69-
thread_tile = [],
70-
element_tile = [],
71-
72+
subgroup_tile = [],
73+
batch_tile = [],
74+
outer_tile = [],
75+
thread_tile = [],
76+
element_tile = [],
7277
subgroup_strides = [],
73-
thread_strides = []
78+
thread_strides = []
7479
>
7580

7681
// CHECK-LABEL: @distribute_scf_for_0d

0 commit comments

Comments
 (0)