Skip to content

Commit d655d8b

Browse files
committed
refac dialect filtering
1 parent f11c4c0 commit d655d8b

File tree

2 files changed

+39
-47
lines changed

2 files changed

+39
-47
lines changed

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -378,37 +378,33 @@ struct LowerGpuOpsToNVVMOpsPass final
378378
RewritePatternSet llvmPatterns(m.getContext());
379379
LLVMConversionTarget target(getContext());
380380

381-
if (!allowedDialects.empty()) {
382-
for (StringRef dialectName : allowedDialects) {
383-
Dialect *dialect = getContext().getLoadedDialect(dialectName);
384-
// Dialect may not be loaded if it wasn't used in source module, ignore.
385-
if (!dialect)
386-
continue;
387-
388-
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
389-
if (!iface) {
381+
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
382+
allowedDialects.end());
383+
for (Dialect *dialect : getContext().getLoadedDialects()) {
384+
// Skip math patterns as nvvm needs custom math lowering.
385+
if (isa<math::MathDialect>(dialect))
386+
continue;
387+
388+
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
389+
// Empty `allowedDialectsSet` means all dialects are allowed.
390+
if (!allowedDialectsSet.empty() && !allowed)
391+
continue;
392+
393+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
394+
if (!iface) {
395+
// Error out if dialect was explicily specified but doesn't implement
396+
// conversion interface.
397+
if (allowed) {
390398
m.emitError()
391399
<< "dialect does not implement ConvertToLLVMPatternInterface: "
392-
<< dialectName;
400+
<< dialect->getNamespace();
393401
return signalPassFailure();
394402
}
395-
396-
iface->populateConvertToLLVMConversionPatterns(target, converter,
397-
llvmPatterns);
403+
continue;
398404
}
399-
} else {
400-
for (Dialect *dialect : getContext().getLoadedDialects()) {
401-
// Skip math patterns as nvvm needs custom math lowering.
402-
if (isa<math::MathDialect>(dialect))
403-
continue;
404405

405-
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
406-
if (!iface)
407-
continue;
408-
409-
iface->populateConvertToLLVMConversionPatterns(target, converter,
410-
llvmPatterns);
411-
}
406+
iface->populateConvertToLLVMConversionPatterns(target, converter,
407+
llvmPatterns);
412408
}
413409

414410
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -293,33 +293,29 @@ struct LowerGpuOpsToROCDLOpsPass final
293293
RewritePatternSet llvmPatterns(ctx);
294294
LLVMConversionTarget target(getContext());
295295

296-
if (!allowedDialects.empty()) {
297-
for (StringRef dialectName : allowedDialects) {
298-
Dialect *dialect = ctx->getLoadedDialect(dialectName);
299-
// Dialect may not be loaded if it wasn't used in source module, ignore.
300-
if (!dialect)
301-
continue;
302-
303-
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
304-
if (!iface) {
296+
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
297+
allowedDialects.end());
298+
for (Dialect *dialect : ctx->getLoadedDialects()) {
299+
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
300+
// Empty `allowedDialectsSet` means all dialects are allowed.
301+
if (!allowedDialectsSet.empty() && !allowed)
302+
continue;
303+
304+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
305+
if (!iface) {
306+
// Error out if dialect was explicily specified but doesn't implement
307+
// conversion interface.
308+
if (allowed) {
305309
m.emitError()
306310
<< "dialect does not implement ConvertToLLVMPatternInterface: "
307-
<< dialectName;
311+
<< dialect->getNamespace();
308312
return signalPassFailure();
309313
}
310-
311-
iface->populateConvertToLLVMConversionPatterns(target, converter,
312-
llvmPatterns);
313-
}
314-
} else {
315-
for (Dialect *dialect : ctx->getLoadedDialects()) {
316-
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
317-
if (!iface)
318-
continue;
319-
320-
iface->populateConvertToLLVMConversionPatterns(target, converter,
321-
llvmPatterns);
314+
continue;
322315
}
316+
317+
iface->populateConvertToLLVMConversionPatterns(target, converter,
318+
llvmPatterns);
323319
}
324320

325321
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,

0 commit comments

Comments
 (0)