@@ -378,16 +378,36 @@ struct LowerGpuOpsToNVVMOpsPass
378378 RewritePatternSet llvmPatterns (m.getContext ());
379379 LLVMConversionTarget target (getContext ());
380380
381- for (Dialect *dialect : getContext ().getLoadedDialects ()) {
382- if (isa<math::MathDialect>(dialect))
383- continue ;
381+ if (!filterDialects.empty ()) {
382+ for (StringRef dialectName : filterDialects) {
383+ Dialect *dialect = getContext ().getLoadedDialect (dialectName);
384+ // Dialect may not be loaded if it wasn't used in source module, ignore.
385+ if (!dialect)
386+ continue ;
387+
388+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
389+ if (!iface) {
390+ m.emitError ()
391+ << " dialect does not implement ConvertToLLVMPatternInterface: "
392+ << dialectName << " \n " ;
393+ return signalPassFailure ();
394+ }
384395
385- auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
386- if (!iface)
387- continue ;
396+ iface->populateConvertToLLVMConversionPatterns (target, converter,
397+ llvmPatterns);
398+ }
399+ } else {
400+ for (Dialect *dialect : getContext ().getLoadedDialects ()) {
401+ if (isa<math::MathDialect>(dialect)) // Need custom math lowering
402+ continue ;
388403
389- iface->populateConvertToLLVMConversionPatterns (target, converter,
390- llvmPatterns);
404+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
405+ if (!iface)
406+ continue ;
407+
408+ iface->populateConvertToLLVMConversionPatterns (target, converter,
409+ llvmPatterns);
410+ }
391411 }
392412
393413 populateGpuToNVVMConversionPatterns (converter, llvmPatterns);
@@ -404,6 +424,7 @@ struct LowerGpuOpsToNVVMOpsPass
404424
405425void mlir::configureGpuToNVVMConversionLegality (ConversionTarget &target) {
406426 target.addIllegalOp <func::FuncOp>();
427+ target.addIllegalOp <cf::AssertOp>();
407428 target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
408429 target.addLegalDialect <::mlir::NVVM::NVVMDialect>();
409430 target.addIllegalDialect <gpu::GPUDialect>();
0 commit comments