88
99#include " mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
1010
11- #include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
12- #include " mlir/Conversion/LLVMCommon/Pattern.h"
13- #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
1411#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15- #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1612#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1713#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
1814#include " mlir/IR/BuiltinTypes.h"
2319#include " mlir/Dialect/GPU/IR/GPUDialect.h"
2420#include " mlir/Dialect/Vector/IR/VectorOps.h"
2521
22+ #include " mlir/Transforms/WalkPatternRewriteDriver.h"
2623#include " llvm/Support/FormatVariadic.h"
27- #include " llvm/Support/MathExtras.h"
28- #include < cassert>
29- #include < cstdint>
30-
31- #include " ../LLVMCommon/MemRefDescriptor.h"
32-
33- #include " llvm/ADT/STLExtras.h"
34- #include < optional>
3524
3625namespace mlir {
3726#define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
@@ -180,24 +169,17 @@ struct ConvertGPUToAMDGPUPass
180169
181170 void runOnOperation () override {
182171 RewritePatternSet patterns (&getContext ());
183- LLVMTypeConverter converter (&getContext ());
184- LLVMConversionTarget target (getContext ());
185- target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
186- target.addLegalDialect <::mlir::amdgpu::AMDGPUDialect>();
187- target.addLegalDialect <::mlir::ROCDL::ROCDLDialect>();
188-
189172 int subgroupSizeInt = static_cast <int >(subgroupSize);
190- populateSubgroupReduceLoweringPatterns (converter, patterns, subgroupSizeInt,
173+ populateSubgroupReduceLoweringPatterns (patterns, subgroupSizeInt,
191174 PatternBenefit (1 ));
192- if (failed (applyPartialConversion (getOperation (), target,
193- std::move (patterns))))
194- signalPassFailure ();
175+ walkAndApplyPatterns (getOperation (), std::move (patterns));
195176 }
196177};
197178} // namespace
198179
199- void mlir::populateSubgroupReduceLoweringPatterns (
200- LLVMTypeConverter &converter, RewritePatternSet &patterns, unsigned subgroupSize, PatternBenefit benefit) {
180+ void mlir::populateSubgroupReduceLoweringPatterns (RewritePatternSet &patterns,
181+ unsigned subgroupSize,
182+ PatternBenefit benefit) {
201183 patterns.add <ScalarSubgroupReduceToShuffles>(
202184 patterns.getContext (), subgroupSize, /* matchClustered=*/ true , benefit);
203185}
0 commit comments