Skip to content

Commit 1f353d4

Browse files
committed
gpu-to-rocd filter-dialects
1 parent ee88eb5 commit 1f353d4

File tree

3 files changed

+39
-14
lines changed

3 files changed

+39
-14
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -578,20 +578,24 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
578578
/*default=*/"\"gfx000\"",
579579
"Chipset that these operations will run on">,
580580
Option<"indexBitwidth", "index-bitwidth", "unsigned",
581-
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
581+
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
582582
"Bitwidth of the index type, 0 to use size of machine word">,
583583
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
584584
/*default=*/"false",
585585
"Replace memref arguments in GPU functions with bare pointers."
586586
"All memrefs must have static shape">,
587587
Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
588-
"::mlir::gpu::amd::Runtime::Unknown",
589-
"Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
590-
[{::llvm::cl::values(
591-
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
592-
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
593-
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
594-
)}]>
588+
"::mlir::gpu::amd::Runtime::Unknown",
589+
"Runtime code will be run on (default is Unknown, can also use HIP "
590+
"or OpenCl)",
591+
[{::llvm::cl::values(
592+
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown",
593+
"Unknown (default)"),
594+
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
595+
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
596+
"OpenCL"))}]>,
597+
ListOption<"filterDialects", "filter-dialects", "std::string",
598+
"Run conversion patterns of only the specified dialects">,
595599
];
596600
}
597601

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,33 @@ struct LowerGpuOpsToROCDLOpsPass
293293
RewritePatternSet llvmPatterns(ctx);
294294
LLVMConversionTarget target(getContext());
295295

296-
for (Dialect *dialect : ctx->getLoadedDialects()) {
297-
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
298-
if (!iface)
299-
continue;
296+
if (!filterDialects.empty()) {
297+
for (StringRef dialectName : filterDialects) {
298+
Dialect *dialect = ctx->getLoadedDialect(dialectName);
299+
// Dialect may not be loaded if it wasn't used in source module, ignore.
300+
if (!dialect)
301+
continue;
302+
303+
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
304+
if (!iface) {
305+
m.emitError()
306+
<< "dialect does not implement ConvertToLLVMPatternInterface: "
307+
<< dialectName << "\n";
308+
return signalPassFailure();
309+
}
300310

301-
iface->populateConvertToLLVMConversionPatterns(target, converter,
302-
llvmPatterns);
311+
iface->populateConvertToLLVMConversionPatterns(target, converter,
312+
llvmPatterns);
313+
}
314+
} else {
315+
for (Dialect *dialect : ctx->getLoadedDialects()) {
316+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
317+
if (!iface)
318+
continue;
319+
320+
iface->populateConvertToLLVMConversionPatterns(target, converter,
321+
llvmPatterns);
322+
}
303323
}
304324

305325
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -convert-gpu-to-rocdl='filter-dialects=func,arith,math' -split-input-file | FileCheck %s
23
// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
34

45
// CHECK-LABEL: @test_module

0 commit comments

Comments
 (0)