77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
10- # include " ../PassDetail.h "
10+
1111#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
1212#include " mlir/Conversion/LLVMCommon/Pattern.h"
13- #include " mlir/Dialect/AMDGPU/AMDGPUDialect.h"
13+ #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
14+ #include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15+ #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
16+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1417#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
18+ #include " mlir/IR/BuiltinTypes.h"
19+ #include " mlir/IR/TypeUtilities.h"
20+ #include " mlir/Pass/Pass.h"
21+
22+ #include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
23+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
24+ #include " mlir/Dialect/Vector/IR/VectorOps.h"
25+
26+ #include " llvm/Support/FormatVariadic.h"
27+ #include " llvm/Support/MathExtras.h"
28+ #include < cassert>
29+ #include < cstdint>
30+
31+ #include " ../LLVMCommon/MemRefDescriptor.h"
32+
33+ #include " llvm/ADT/STLExtras.h"
34+ #include < optional>
35+
36+ namespace mlir {
37+ #define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
38+ #include " mlir/Conversion/Passes.h.inc"
39+ } // namespace mlir
1540
1641using namespace mlir ;
1742
1843namespace {
44+ struct ClusterInfo {
45+ unsigned clusterStride;
46+ unsigned clusterSize;
47+ unsigned subgroupSize;
48+ };
49+
50+ static FailureOr<ClusterInfo>
51+ getAndValidateClusterInfo (gpu::SubgroupReduceOp op, unsigned subgroupSize) {
52+ assert (llvm::isPowerOf2_32 (subgroupSize));
53+
54+ std::optional<uint32_t > clusterSize = op.getClusterSize ();
55+ assert (!clusterSize ||
56+ llvm::isPowerOf2_32 (*clusterSize)); // Verifier should've caught this.
57+ if (clusterSize && *clusterSize > subgroupSize)
58+ return op.emitOpError ()
59+ << " cluster size " << *clusterSize
60+ << " is greater than subgroup size " << subgroupSize;
61+ unsigned effectiveClusterSize = clusterSize.value_or (subgroupSize);
62+
63+ auto clusterStride = op.getClusterStride ();
64+ assert (llvm::isPowerOf2_32 (clusterStride)); // Verifier should've caught this.
65+ if (clusterStride >= subgroupSize)
66+ return op.emitOpError ()
67+ << " cluster stride " << clusterStride
68+ << " is not less than subgroup size " << subgroupSize;
69+
70+ return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
71+ }
72+
73+ Value createSubgroupDPPReduction (OpBuilder &b, Location loc, Value input,
74+ gpu::AllReduceOperation mode,
75+ const ClusterInfo &ci) {
76+ Value result = input;
77+ if (ci.clusterSize >= 2 ) {
78+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 1 );
79+ Value dppResult =
80+ b.create <amdgpu::DPPOp>(loc, result.getType (), result, result,
81+ amdgpu::DPPPerm::row_shr, permArg);
82+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
83+ result, dppResult);
84+ }
85+
86+ if (ci.clusterSize >= 4 ) {
87+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 2 );
88+ Value dppResult =
89+ b.create <amdgpu::DPPOp>(loc, result.getType (), result, result,
90+ amdgpu::DPPPerm::row_shr, permArg);
91+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
92+ result, dppResult);
93+ }
94+
95+ if (ci.clusterSize >= 8 ) {
96+ Value dppResult = b.create <amdgpu::DPPOp>(
97+ loc, result.getType (), result, result, amdgpu::DPPPerm::row_half_mirror,
98+ b.getUnitAttr ());
99+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
100+ result, dppResult);
101+ }
102+
103+ if (ci.clusterSize >= 16 ) {
104+ Value dppResult =
105+ b.create <amdgpu::DPPOp>(loc, result.getType (), result, result,
106+ amdgpu::DPPPerm::row_mirror, b.getUnitAttr ());
107+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
108+ result, dppResult);
109+ }
110+
111+ if (ci.clusterSize >= 32 ) {
112+ // auto permArg = builder.getInt32(15);
113+ // auto rowMask = builder.getInt32("0xa");
114+ // auto bankMask = builder.getInt32("0xf");
115+ // auto boundCtrl = builder.getBoolAttr(false);
116+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 15 );
117+ Value dppResult = b.create <amdgpu::DPPOp>(
118+ loc, result.getType (), result, result, amdgpu::DPPPerm::row_bcast_15,
119+ b.getUnitAttr (), 10 , 15 , false );
120+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
121+ result, dppResult);
122+ }
123+
124+ if (ci.clusterSize == 64 ) {
125+ // auto permArg = builder.getInt32(31);
126+ // auto rowMask = builder.getInt32("0xc");
127+ // auto bankMask = builder.getInt32("0xf");
128+ // auto boundCtrl = builder.getBoolAttr(false);
129+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 31 );
130+ Value dppResult = b.create <amdgpu::DPPOp>(
131+ loc, result.getType (), result, result, amdgpu::DPPPerm::row_bcast_31,
132+ b.getUnitAttr (), 12 , 15 , false );
133+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
134+ result, dppResult);
135+ }
136+
137+ // // read lane 63 with the final result.
138+ // auto lane = b.getIntegerAttr(b.getIntegerType(32), 63);
139+ // result = b.create<ROCDL::ReadLaneOp>(loc, input.getType(), result, lane);
140+ assert (result.getType () == input.getType ());
141+ return result;
142+ }
143+
144+ struct ScalarSubgroupReduceToShuffles final
145+ : OpRewritePattern<gpu::SubgroupReduceOp> {
146+ ScalarSubgroupReduceToShuffles (MLIRContext *ctx, unsigned subgroupSize,
147+ bool matchClustered, PatternBenefit benefit)
148+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
149+ matchClustered (matchClustered) {}
150+
151+ LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
152+ PatternRewriter &rewriter) const override {
153+ llvm::errs () << " ScalarSubgroupReduceToShuffles" << " \n " ;
154+ if (op.getClusterSize ().has_value () != matchClustered) {
155+ return rewriter.notifyMatchFailure (
156+ op, llvm::formatv (" op is {0}clustered but pattern is configured to "
157+ " only match {1}clustered ops" ,
158+ matchClustered ? " non-" : " " ,
159+ matchClustered ? " " : " non-" ));
160+ }
161+
162+ auto ci = getAndValidateClusterInfo (op, subgroupSize);
163+ if (failed (ci))
164+ return failure ();
165+
166+ Location loc = op.getLoc ();
167+ rewriter.replaceOp (op, createSubgroupDPPReduction (
168+ rewriter, loc, op.getValue (), op.getOp (), *ci));
169+ return success ();
170+ }
171+
172+ private:
173+ unsigned subgroupSize = 0 ;
174+ bool matchClustered = false ;
175+ };
176+
19177struct ConvertGPUToAMDGPUPass
20- : public ConvertGPUToAMDGPUBase <ConvertGPUToAMDGPUPass> {
21- ConvertGPUToAMDGPUPass () = default ;
178+ : public impl::ConvertGPUToAMDGPUPassBase <ConvertGPUToAMDGPUPass> {
179+ using Base::Base ;
22180
23181 void runOnOperation () override {
24182 RewritePatternSet patterns (&getContext ());
25183 LLVMTypeConverter converter (&getContext ());
26- populateGPUToAMDGPUConversionPatterns (converter, patterns);
27184 LLVMConversionTarget target (getContext ());
28185 target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
29- target.addLegalDialect <::mlir::AMDGPU ::AMDGPUDialect>();
186+ target.addLegalDialect <::mlir::amdgpu ::AMDGPUDialect>();
30187 target.addLegalDialect <::mlir::ROCDL::ROCDLDialect>();
188+
189+ int subgroupSizeInt = static_cast <int >(subgroupSize);
190+ populateSubgroupReduceLoweringPatterns (converter, patterns, subgroupSizeInt,
191+ PatternBenefit (1 ));
31192 if (failed (applyPartialConversion (getOperation (), target,
32193 std::move (patterns))))
33194 signalPassFailure ();
34195 }
35196};
36197} // namespace
37198
38- void mlir::populateGPUToAMDGPUConversionPatterns (
39- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
40- }
41-
42- std::unique_ptr<Pass> mlir::createConvertGPUToAMDGPUPass () {
43- return std::make_unique<ConvertGPUToAMDGPUPass>();
199+ void mlir::populateSubgroupReduceLoweringPatterns (
200+ LLVMTypeConverter &converter, RewritePatternSet &patterns, unsigned subgroupSize, PatternBenefit benefit) {
201+ patterns.add <ScalarSubgroupReduceToShuffles>(
202+ patterns.getContext (), subgroupSize, /* matchClustered=*/ true , benefit);
44203}
0 commit comments