Skip to content

Commit 1b3f7e4

Browse files
committed
gpu-to-nvvm filter-dialects
1 parent 1f353d4 commit 1b3f7e4

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,14 +550,16 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
550550
];
551551
let options = [
552552
Option<"indexBitwidth", "index-bitwidth", "unsigned",
553-
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
553+
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
554554
"Bitwidth of the index type, 0 to use size of machine word">,
555555
Option<"hasRedux", "has-redux", "bool", /*default=*/"false",
556556
"Target gpu supports redux">,
557557
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
558558
/*default=*/"false",
559559
"Replace memref arguments in GPU functions with bare pointers. "
560-
"All memrefs must have static shape.">
560+
"All memrefs must have static shape.">,
561+
ListOption<"filterDialects", "filter-dialects", "std::string",
562+
"Run conversion patterns of only the specified dialects">,
561563
];
562564
}
563565

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,16 +378,36 @@ struct LowerGpuOpsToNVVMOpsPass
378378
RewritePatternSet llvmPatterns(m.getContext());
379379
LLVMConversionTarget target(getContext());
380380

381-
for (Dialect *dialect : getContext().getLoadedDialects()) {
382-
if (isa<math::MathDialect>(dialect))
383-
continue;
381+
if (!filterDialects.empty()) {
382+
for (StringRef dialectName : filterDialects) {
383+
Dialect *dialect = getContext().getLoadedDialect(dialectName);
384+
// Dialect may not be loaded if it wasn't used in source module, ignore.
385+
if (!dialect)
386+
continue;
387+
388+
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
389+
if (!iface) {
390+
m.emitError()
391+
<< "dialect does not implement ConvertToLLVMPatternInterface: "
392+
<< dialectName << "\n";
393+
return signalPassFailure();
394+
}
384395

385-
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
386-
if (!iface)
387-
continue;
396+
iface->populateConvertToLLVMConversionPatterns(target, converter,
397+
llvmPatterns);
398+
}
399+
} else {
400+
for (Dialect *dialect : getContext().getLoadedDialects()) {
401+
if (isa<math::MathDialect>(dialect)) // Need custom math lowering
402+
continue;
388403

389-
iface->populateConvertToLLVMConversionPatterns(target, converter,
390-
llvmPatterns);
404+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
405+
if (!iface)
406+
continue;
407+
408+
iface->populateConvertToLLVMConversionPatterns(target, converter,
409+
llvmPatterns);
410+
}
391411
}
392412

393413
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
@@ -404,6 +424,7 @@ struct LowerGpuOpsToNVVMOpsPass
404424

405425
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
406426
target.addIllegalOp<func::FuncOp>();
427+
target.addIllegalOp<cf::AssertOp>();
407428
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
408429
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
409430
target.addIllegalDialect<gpu::GPUDialect>();

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 filter-dialects=func,arith,cf' -split-input-file | FileCheck %s
23
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
34
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
45

0 commit comments

Comments
 (0)