Skip to content

Commit ee88eb5

Browse files
committed
[mlir][gpu] GPUToROCDL/NVVM: use generic llvm conversion interface instead of hardcoded connversions.
1 parent 7aabbf2 commit ee88eb5

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,14 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
15-
16-
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
17-
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1814
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
19-
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
2016
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
2117
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
18+
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
2219
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
2320
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2421
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
25-
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
26-
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
2722
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
2823
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2924
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -346,6 +341,11 @@ struct LowerGpuOpsToNVVMOpsPass
346341
: public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
347342
using Base::Base;
348343

344+
void getDependentDialects(DialectRegistry &registry) const override final {
345+
Base::getDependentDialects(registry);
346+
registerConvertToLLVMDependentDialectLoading(registry);
347+
}
348+
349349
void runOnOperation() override {
350350
gpu::GPUModuleOp m = getOperation();
351351

@@ -376,17 +376,24 @@ struct LowerGpuOpsToNVVMOpsPass
376376
LLVMTypeConverter converter(m.getContext(), options);
377377
configureGpuToNVVMTypeConverter(converter);
378378
RewritePatternSet llvmPatterns(m.getContext());
379+
LLVMConversionTarget target(getContext());
380+
381+
for (Dialect *dialect : getContext().getLoadedDialects()) {
382+
if (isa<math::MathDialect>(dialect))
383+
continue;
384+
385+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
386+
if (!iface)
387+
continue;
388+
389+
iface->populateConvertToLLVMConversionPatterns(target, converter,
390+
llvmPatterns);
391+
}
379392

380-
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
381-
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
382-
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
383-
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
384393
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
385394
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
386-
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
387395
if (this->hasRedux)
388396
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
389-
LLVMConversionTarget target(getContext());
390397
configureGpuToNVVMConversionLegality(target);
391398
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
392399
signalPassFailure();
@@ -472,8 +479,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
472479
using gpu::index_lowering::IndexKind;
473480
using gpu::index_lowering::IntrType;
474481
populateWithGenerated(patterns);
482+
483+
// Set higher benefit, so patterns will run before generic LLVM lowering.
475484
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
476-
converter);
485+
converter, /*benefit*/ 10);
477486
patterns.add<
478487
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
479488
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,22 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1514
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
1615
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1716
#include "mlir/Pass/Pass.h"
1817
#include "mlir/Pass/PassManager.h"
1918
#include "mlir/Transforms/Passes.h"
2019

2120
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
22-
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
23-
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
22+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
2423
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
2524
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
2625
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2726
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2827
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
2928
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
3029
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
31-
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
32-
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
3330
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
3431
#include "mlir/Dialect/Func/IR/FuncOps.h"
3532
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -218,6 +215,11 @@ struct LowerGpuOpsToROCDLOpsPass
218215
this->runtime = runtime;
219216
}
220217

218+
void getDependentDialects(DialectRegistry &registry) const override final {
219+
Base::getDependentDialects(registry);
220+
registerConvertToLLVMDependentDialectLoading(registry);
221+
}
222+
221223
void runOnOperation() override {
222224
gpu::GPUModuleOp m = getOperation();
223225
MLIRContext *ctx = m.getContext();
@@ -289,18 +291,20 @@ struct LowerGpuOpsToROCDLOpsPass
289291
});
290292

291293
RewritePatternSet llvmPatterns(ctx);
294+
LLVMConversionTarget target(getContext());
295+
296+
for (Dialect *dialect : ctx->getLoadedDialects()) {
297+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
298+
if (!iface)
299+
continue;
300+
301+
iface->populateConvertToLLVMConversionPatterns(target, converter,
302+
llvmPatterns);
303+
}
292304

293-
mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
294305
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
295306
*maybeChipset);
296-
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
297-
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
298-
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
299-
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
300-
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
301-
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
302307
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
303-
LLVMConversionTarget target(getContext());
304308
configureGpuToROCDLConversionLegality(target);
305309
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
306310
signalPassFailure();

0 commit comments

Comments
 (0)