@@ -378,37 +378,33 @@ struct LowerGpuOpsToNVVMOpsPass final
378378 RewritePatternSet llvmPatterns (m.getContext ());
379379 LLVMConversionTarget target (getContext ());
380380
381- if (!allowedDialects.empty ()) {
382- for (StringRef dialectName : allowedDialects) {
383- Dialect *dialect = getContext ().getLoadedDialect (dialectName);
384- // Dialect may not be loaded if it wasn't used in source module, ignore.
385- if (!dialect)
386- continue ;
387-
388- auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
389- if (!iface) {
381+ llvm::SmallDenseSet<StringRef> allowedDialectsSet (allowedDialects.begin (),
382+ allowedDialects.end ());
383+ for (Dialect *dialect : getContext ().getLoadedDialects ()) {
384+ // Skip math patterns as nvvm needs custom math lowering.
385+ if (isa<math::MathDialect>(dialect))
386+ continue ;
387+
388+ bool allowed = allowedDialectsSet.contains (dialect->getNamespace ());
389+ // Empty `allowedDialectsSet` means all dialects are allowed.
390+ if (!allowedDialectsSet.empty () && !allowed)
391+ continue ;
392+
393+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
394+ if (!iface) {
395+ // Error out if dialect was explicily specified but doesn't implement
396+ // conversion interface.
397+ if (allowed) {
390398 m.emitError ()
391399 << " dialect does not implement ConvertToLLVMPatternInterface: "
392- << dialectName ;
400+ << dialect-> getNamespace () ;
393401 return signalPassFailure ();
394402 }
395-
396- iface->populateConvertToLLVMConversionPatterns (target, converter,
397- llvmPatterns);
403+ continue ;
398404 }
399- } else {
400- for (Dialect *dialect : getContext ().getLoadedDialects ()) {
401- // Skip math patterns as nvvm needs custom math lowering.
402- if (isa<math::MathDialect>(dialect))
403- continue ;
404405
405- auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
406- if (!iface)
407- continue ;
408-
409- iface->populateConvertToLLVMConversionPatterns (target, converter,
410- llvmPatterns);
411- }
406+ iface->populateConvertToLLVMConversionPatterns (target, converter,
407+ llvmPatterns);
412408 }
413409
414410 populateGpuToNVVMConversionPatterns (converter, llvmPatterns);
0 commit comments