Skip to content

Commit 2eb3cb8

Browse files
making things compile
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent f63ab5d commit 2eb3cb8

File tree

5 files changed

+190
-19
lines changed

5 files changed

+190
-19
lines changed

mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,27 @@
88
#ifndef MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPU_H_
99
#define MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPU_H_
1010

11+
12+
#include "mlir/IR/PatternMatch.h"
1113
#include <memory>
14+
#include <string>
1215

1316
namespace mlir {
1417

1518
class LLVMTypeConverter;
1619
class RewritePatternSet;
20+
class TypeConverter;
1721
class Pass;
1822

19-
void populateGPUToAMDGPUConversionPatterns(LLVMTypeConverter &converter,
20-
RewritePatternSet &patterns);
23+
#define GEN_PASS_DECL_CONVERTGPUTOAMDGPUPASS
24+
#include "mlir/Conversion/Passes.h.inc"
2125

22-
std::unique_ptr<Pass> createConvertGPUToAMDGPUPass();
26+
void populateSubgroupReduceLoweringPatterns(LLVMTypeConverter &converter,
27+
RewritePatternSet &patterns,
28+
unsigned subgroupSize,
29+
PatternBenefit benefit);
30+
// void populateGPUToAMDGPUConversionPatterns(LLVMTypeConverter &converter,
31+
// RewritePatternSet &patterns);
2332

2433
} // namespace mlir
2534

mlir/include/mlir/Conversion/Passes.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -650,12 +650,13 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
650650
def ConvertGPUToAMDGPUPass : Pass<"convert-gpu-to-amdgpu"> {
651651
let summary = "Generate AMDGPU operations for gpu operations";
652652
let dependentDialects = [
653+
"LLVM::LLVMDialect",
653654
"::mlir::gpu::GPUDialect",
654-
"amdgpu::AMDGPUDialect"
655+
"amdgpu::AMDGPUDialect",
655656
];
656-
// let options = [Option<"chipset", "chipset", "std::string",
657-
// /*default=*/"\"gfx000\"",
658-
// "Chipset that these operations will run on">];
657+
let options = [Option<"subgroupSize", "subgroup-size", "unsigned",
658+
/*default=*/"64",
659+
"Size of subgroup">];
659660
}
660661

661662
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRGPUToAMDGPU
1515
MLIRLLVMDialect
1616
MLIRGPUDialect
1717
MLIRAMDGPUDialect
18+
MLIRAMDGPUUtils
1819
MLIRROCDLDialect
1920
MLIRPass
2021
MLIRTransforms

mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp

Lines changed: 171 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,197 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
10-
#include "../PassDetail.h"
10+
1111
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1212
#include "mlir/Conversion/LLVMCommon/Pattern.h"
13-
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
13+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1417
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18+
#include "mlir/IR/BuiltinTypes.h"
19+
#include "mlir/IR/TypeUtilities.h"
20+
#include "mlir/Pass/Pass.h"
21+
22+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
23+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
24+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
25+
26+
#include "llvm/Support/FormatVariadic.h"
27+
#include "llvm/Support/MathExtras.h"
28+
#include <cassert>
29+
#include <cstdint>
30+
31+
#include "../LLVMCommon/MemRefDescriptor.h"
32+
33+
#include "llvm/ADT/STLExtras.h"
34+
#include <optional>
35+
36+
namespace mlir {
37+
#define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
38+
#include "mlir/Conversion/Passes.h.inc"
39+
} // namespace mlir
1540

1641
using namespace mlir;
1742

1843
namespace {
44+
struct ClusterInfo {
45+
unsigned clusterStride;
46+
unsigned clusterSize;
47+
unsigned subgroupSize;
48+
};
49+
50+
static FailureOr<ClusterInfo>
51+
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
52+
assert(llvm::isPowerOf2_32(subgroupSize));
53+
54+
std::optional<uint32_t> clusterSize = op.getClusterSize();
55+
assert(!clusterSize ||
56+
llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
57+
if (clusterSize && *clusterSize > subgroupSize)
58+
return op.emitOpError()
59+
<< "cluster size " << *clusterSize
60+
<< " is greater than subgroup size " << subgroupSize;
61+
unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
62+
63+
auto clusterStride = op.getClusterStride();
64+
assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
65+
if (clusterStride >= subgroupSize)
66+
return op.emitOpError()
67+
<< "cluster stride " << clusterStride
68+
<< " is not less than subgroup size " << subgroupSize;
69+
70+
return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
71+
}
72+
73+
Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
74+
gpu::AllReduceOperation mode,
75+
const ClusterInfo &ci) {
76+
Value result = input;
77+
if (ci.clusterSize >= 2) {
78+
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 1);
79+
Value dppResult =
80+
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
81+
amdgpu::DPPPerm::row_shr, permArg);
82+
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
83+
result, dppResult);
84+
}
85+
86+
if (ci.clusterSize >= 4) {
87+
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 2);
88+
Value dppResult =
89+
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
90+
amdgpu::DPPPerm::row_shr, permArg);
91+
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
92+
result, dppResult);
93+
}
94+
95+
if (ci.clusterSize >= 8) {
96+
Value dppResult = b.create<amdgpu::DPPOp>(
97+
loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
98+
b.getUnitAttr());
99+
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
100+
result, dppResult);
101+
}
102+
103+
if (ci.clusterSize >= 16) {
104+
Value dppResult =
105+
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
106+
amdgpu::DPPPerm::row_mirror, b.getUnitAttr());
107+
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
108+
result, dppResult);
109+
}
110+
111+
if (ci.clusterSize >= 32) {
112+
// auto permArg = builder.getInt32(15);
113+
// auto rowMask = builder.getInt32("0xa");
114+
// auto bankMask = builder.getInt32("0xf");
115+
// auto boundCtrl = builder.getBoolAttr(false);
116+
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 15);
117+
Value dppResult = b.create<amdgpu::DPPOp>(
118+
loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
119+
b.getUnitAttr(), 10, 15, false);
120+
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
121+
result, dppResult);
122+
}
123+
124+
if (ci.clusterSize == 64) {
125+
// auto permArg = builder.getInt32(31);
126+
// auto rowMask = builder.getInt32("0xc");
127+
// auto bankMask = builder.getInt32("0xf");
128+
// auto boundCtrl = builder.getBoolAttr(false);
129+
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 31);
130+
Value dppResult = b.create<amdgpu::DPPOp>(
131+
loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_31,
132+
b.getUnitAttr(), 12, 15, false);
133+
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
134+
result, dppResult);
135+
}
136+
137+
// // read lane 63 with the final result.
138+
// auto lane = b.getIntegerAttr(b.getIntegerType(32), 63);
139+
// result = b.create<ROCDL::ReadLaneOp>(loc, input.getType(), result, lane);
140+
assert(result.getType() == input.getType());
141+
return result;
142+
}
143+
144+
struct ScalarSubgroupReduceToShuffles final
145+
: OpRewritePattern<gpu::SubgroupReduceOp> {
146+
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
147+
bool matchClustered, PatternBenefit benefit)
148+
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
149+
matchClustered(matchClustered) {}
150+
151+
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
152+
PatternRewriter &rewriter) const override {
153+
llvm::errs() << "ScalarSubgroupReduceToShuffles" << "\n";
154+
if (op.getClusterSize().has_value() != matchClustered) {
155+
return rewriter.notifyMatchFailure(
156+
op, llvm::formatv("op is {0}clustered but pattern is configured to "
157+
"only match {1}clustered ops",
158+
matchClustered ? "non-" : "",
159+
matchClustered ? "" : "non-"));
160+
}
161+
162+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
163+
if (failed(ci))
164+
return failure();
165+
166+
Location loc = op.getLoc();
167+
rewriter.replaceOp(op, createSubgroupDPPReduction(
168+
rewriter, loc, op.getValue(), op.getOp(), *ci));
169+
return success();
170+
}
171+
172+
private:
173+
unsigned subgroupSize = 0;
174+
bool matchClustered = false;
175+
};
176+
19177
struct ConvertGPUToAMDGPUPass
20-
: public ConvertGPUToAMDGPUBase<ConvertGPUToAMDGPUPass> {
21-
ConvertGPUToAMDGPUPass() = default;
178+
: public impl::ConvertGPUToAMDGPUPassBase<ConvertGPUToAMDGPUPass> {
179+
using Base::Base;
22180

23181
void runOnOperation() override {
24182
RewritePatternSet patterns(&getContext());
25183
LLVMTypeConverter converter(&getContext());
26-
populateGPUToAMDGPUConversionPatterns(converter, patterns);
27184
LLVMConversionTarget target(getContext());
28185
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
29-
target.addLegalDialect<::mlir::AMDGPU::AMDGPUDialect>();
186+
target.addLegalDialect<::mlir::amdgpu::AMDGPUDialect>();
30187
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
188+
189+
int subgroupSizeInt = static_cast<int>(subgroupSize);
190+
populateSubgroupReduceLoweringPatterns(converter, patterns, subgroupSizeInt,
191+
PatternBenefit(1));
31192
if (failed(applyPartialConversion(getOperation(), target,
32193
std::move(patterns))))
33194
signalPassFailure();
34195
}
35196
};
36197
} // namespace
37198

38-
void mlir::populateGPUToAMDGPUConversionPatterns(
39-
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
40-
}
41-
42-
std::unique_ptr<Pass> mlir::createConvertGPUToAMDGPUPass() {
43-
return std::make_unique<ConvertGPUToAMDGPUPass>();
199+
void mlir::populateSubgroupReduceLoweringPatterns(
200+
LLVMTypeConverter &converter, RewritePatternSet &patterns, unsigned subgroupSize, PatternBenefit benefit) {
201+
patterns.add<ScalarSubgroupReduceToShuffles>(
202+
patterns.getContext(), subgroupSize, /*matchClustered=*/true, benefit);
44203
}

mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
1515
MLIRMathToLLVM
1616
MLIRMathToROCDL
1717
MLIRAMDGPUToROCDL
18+
MLIRGPUToAMDGPU
1819
MLIRFuncToLLVM
1920
MLIRGPUDialect
2021
MLIRGPUToGPURuntimeTransforms

0 commit comments

Comments
 (0)