@@ -366,22 +366,47 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
366366 const fir::LLVMTypeConverter *typeConverter;
367367};
368368
369- static mlir::Value genGetDeviceAddress (mlir::PatternRewriter &rewriter,
370- mlir::ModuleOp mod, mlir::Location loc,
371- mlir::Value inputArg) {
372- fir::FirOpBuilder builder (rewriter, mod);
373- mlir::func::FuncOp callee =
374- fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
375- auto fTy = callee.getFunctionType ();
376- mlir::Value conv = createConvertOp (rewriter, loc, fTy .getInput (0 ), inputArg);
377- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
378- mlir::Value sourceLine =
379- fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
380- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
381- builder, loc, fTy , conv, sourceFile, sourceLine)};
382- auto call = rewriter.create <fir::CallOp>(loc, callee, args);
383- return createConvertOp (rewriter, loc, inputArg.getType (), call->getResult (0 ));
384- }
369+ struct CUFDeviceAddressOpConversion
370+ : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
371+ using OpRewritePattern::OpRewritePattern;
372+
373+ CUFDeviceAddressOpConversion (mlir::MLIRContext *context,
374+ const mlir::SymbolTable &symtab)
375+ : OpRewritePattern(context), symTab{symtab} {}
376+
377+ mlir::LogicalResult
378+ matchAndRewrite (cuf::DeviceAddressOp op,
379+ mlir::PatternRewriter &rewriter) const override {
380+ if (auto global = symTab.lookup <fir::GlobalOp>(
381+ op.getHostSymbol ().getRootReference ().getValue ())) {
382+ auto mod = op->getParentOfType <mlir::ModuleOp>();
383+ mlir::Location loc = op.getLoc ();
384+ auto hostAddr = rewriter.create <fir::AddrOfOp>(
385+ loc, fir::ReferenceType::get (global.getType ()), op.getHostSymbol ());
386+ fir::FirOpBuilder builder (rewriter, mod);
387+ mlir::func::FuncOp callee =
388+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc,
389+ builder);
390+ auto fTy = callee.getFunctionType ();
391+ mlir::Value conv =
392+ createConvertOp (rewriter, loc, fTy .getInput (0 ), hostAddr);
393+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
394+ mlir::Value sourceLine =
395+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
396+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
397+ builder, loc, fTy , conv, sourceFile, sourceLine)};
398+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
399+ mlir::Value addr = createConvertOp (rewriter, loc, hostAddr.getType (),
400+ call->getResult (0 ));
401+ rewriter.replaceOp (op, addr.getDefiningOp ());
402+ return success ();
403+ }
404+ return failure ();
405+ }
406+
407+ private:
408+ const mlir::SymbolTable &symTab;
409+ };
385410
386411struct DeclareOpConversion : public mlir ::OpRewritePattern<fir::DeclareOp> {
387412 using OpRewritePattern::OpRewritePattern;
@@ -398,9 +423,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
398423 addrOfOp.getSymbol ().getRootReference ().getValue ())) {
399424 if (cuf::isRegisteredDeviceGlobal (global)) {
400425 rewriter.setInsertionPointAfter (addrOfOp);
401- auto mod = op->getParentOfType <mlir::ModuleOp>();
402- mlir::Value devAddr = genGetDeviceAddress (rewriter, mod, op.getLoc (),
403- addrOfOp.getResult ());
426+ mlir::Value devAddr = rewriter.create <cuf::DeviceAddressOp>(
427+ op.getLoc (), addrOfOp.getType (), addrOfOp.getSymbol ());
404428 rewriter.startOpModification (op);
405429 op.getMemrefMutable ().assign (devAddr);
406430 rewriter.finalizeOpModification (op);
@@ -773,7 +797,6 @@ struct CUFLaunchOpConversion
773797 }
774798 }
775799 llvm::SmallVector<mlir::Value> args;
776- auto mod = op->getParentOfType <mlir::ModuleOp>();
777800 for (mlir::Value arg : op.getArgs ()) {
778801 // If the argument is a global descriptor, make sure we pass the device
779802 // copy of this descriptor and not the host one.
@@ -785,8 +808,11 @@ struct CUFLaunchOpConversion
785808 if (auto global = symTab.lookup <fir::GlobalOp>(
786809 addrOfOp.getSymbol ().getRootReference ().getValue ())) {
787810 if (cuf::isRegisteredDeviceGlobal (global)) {
788- arg = genGetDeviceAddress (rewriter, mod, op.getLoc (),
789- declareOp.getResult ());
811+ arg = rewriter
812+ .create <cuf::DeviceAddressOp>(op.getLoc (),
813+ addrOfOp.getType (),
814+ addrOfOp.getSymbol ())
815+ .getResult ();
790816 }
791817 }
792818 }
@@ -907,10 +933,12 @@ void cuf::populateCUFToFIRConversionPatterns(
907933 patterns.getContext ());
908934 patterns.insert <CUFDataTransferOpConversion>(patterns.getContext (), symtab,
909935 &dl, &converter);
910- patterns.insert <CUFLaunchOpConversion>(patterns.getContext (), symtab);
936+ patterns.insert <CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
937+ patterns.getContext (), symtab);
911938}
912939
913940void cuf::populateFIRCUFConversionPatterns (const mlir::SymbolTable &symtab,
914941 mlir::RewritePatternSet &patterns) {
915- patterns.insert <DeclareOpConversion>(patterns.getContext (), symtab);
942+ patterns.insert <DeclareOpConversion, CUFDeviceAddressOpConversion>(
943+ patterns.getContext (), symtab);
916944}
0 commit comments