|
11 | 11 | // |
12 | 12 | //===----------------------------------------------------------------------===// |
13 | 13 |
|
14 | | -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" |
15 | | - |
16 | | -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
17 | | -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
18 | 14 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
19 | | -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
| 15 | +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" |
20 | 16 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
21 | 17 | #include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" |
| 18 | +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" |
22 | 19 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
23 | 20 | #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
24 | 21 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
25 | | -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
26 | | -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
27 | 22 | #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
28 | 23 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
29 | 24 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
@@ -346,6 +341,11 @@ struct LowerGpuOpsToNVVMOpsPass |
346 | 341 | : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> { |
347 | 342 | using Base::Base; |
348 | 343 |
|
| 344 | + void getDependentDialects(DialectRegistry ®istry) const override final { |
| 345 | + Base::getDependentDialects(registry); |
| 346 | + registerConvertToLLVMDependentDialectLoading(registry); |
| 347 | + } |
| 348 | + |
349 | 349 | void runOnOperation() override { |
350 | 350 | gpu::GPUModuleOp m = getOperation(); |
351 | 351 |
|
@@ -376,17 +376,24 @@ struct LowerGpuOpsToNVVMOpsPass |
376 | 376 | LLVMTypeConverter converter(m.getContext(), options); |
377 | 377 | configureGpuToNVVMTypeConverter(converter); |
378 | 378 | RewritePatternSet llvmPatterns(m.getContext()); |
| 379 | + LLVMConversionTarget target(getContext()); |
| 380 | + |
| 381 | + for (Dialect *dialect : getContext().getLoadedDialects()) { |
| 382 | + if (isa<math::MathDialect>(dialect)) |
| 383 | + continue; |
| 384 | + |
| 385 | + auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); |
| 386 | + if (!iface) |
| 387 | + continue; |
| 388 | + |
| 389 | + iface->populateConvertToLLVMConversionPatterns(target, converter, |
| 390 | + llvmPatterns); |
| 391 | + } |
379 | 392 |
|
380 | | - arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); |
381 | | - cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); |
382 | | - populateFuncToLLVMConversionPatterns(converter, llvmPatterns); |
383 | | - populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); |
384 | 393 | populateGpuToNVVMConversionPatterns(converter, llvmPatterns); |
385 | 394 | populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); |
386 | | - populateVectorToLLVMConversionPatterns(converter, llvmPatterns); |
387 | 395 | if (this->hasRedux) |
388 | 396 | populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns); |
389 | | - LLVMConversionTarget target(getContext()); |
390 | 397 | configureGpuToNVVMConversionLegality(target); |
391 | 398 | if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) |
392 | 399 | signalPassFailure(); |
@@ -472,8 +479,10 @@ void mlir::populateGpuToNVVMConversionPatterns( |
472 | 479 | using gpu::index_lowering::IndexKind; |
473 | 480 | using gpu::index_lowering::IntrType; |
474 | 481 | populateWithGenerated(patterns); |
| 482 | + |
| 483 | + // Set higher benefit, so patterns will run before generic LLVM lowering. |
475 | 484 | patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>( |
476 | | - converter); |
| 485 | + converter, /*benefit*/ 10); |
477 | 486 | patterns.add< |
478 | 487 | gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp, |
479 | 488 | NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>( |
|
0 commit comments