@@ -77,6 +77,69 @@ static bool hasDoubleDescriptors(OpTy op) {
7777 return false ;
7878}
7979
80+ static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
81+ mlir::Location loc, mlir::Type toTy,
82+ mlir::Value val) {
83+ if (val.getType () != toTy)
84+ return rewriter.create <fir::ConvertOp>(loc, toTy, val);
85+ return val;
86+ }
87+
88+ mlir::Value getDeviceAddress (mlir::PatternRewriter &rewriter,
89+ mlir::OpOperand &operand,
90+ const mlir::SymbolTable &symtab) {
91+ mlir::Value v = operand.get ();
92+ auto declareOp = v.getDefiningOp <fir::DeclareOp>();
93+ if (!declareOp)
94+ return v;
95+
96+ auto addrOfOp = declareOp.getMemref ().getDefiningOp <fir::AddrOfOp>();
97+ if (!addrOfOp)
98+ return v;
99+
100+ auto globalOp = symtab.lookup <fir::GlobalOp>(
101+ addrOfOp.getSymbol ().getRootReference ().getValue ());
102+
103+ if (!globalOp)
104+ return v;
105+
106+ bool isDevGlobal{false };
107+ auto attr = globalOp.getDataAttrAttr ();
108+ if (attr) {
109+ switch (attr.getValue ()) {
110+ case cuf::DataAttribute::Device:
111+ case cuf::DataAttribute::Managed:
112+ case cuf::DataAttribute::Pinned:
113+ isDevGlobal = true ;
114+ break ;
115+ default :
116+ break ;
117+ }
118+ }
119+ if (!isDevGlobal)
120+ return v;
121+ mlir::OpBuilder::InsertionGuard guard (rewriter);
122+ rewriter.setInsertionPoint (operand.getOwner ());
123+ auto loc = declareOp.getLoc ();
124+ auto mod = declareOp->getParentOfType <mlir::ModuleOp>();
125+ fir::FirOpBuilder builder (rewriter, mod);
126+
127+ mlir::func::FuncOp callee =
128+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
129+ auto fTy = callee.getFunctionType ();
130+ auto toTy = fTy .getInput (0 );
131+ mlir::Value inputArg =
132+ createConvertOp (rewriter, loc, toTy, declareOp.getResult ());
133+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
134+ mlir::Value sourceLine =
135+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
136+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
137+ builder, loc, fTy , inputArg, sourceFile, sourceLine)};
138+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
139+
140+ return call->getResult (0 );
141+ }
142+
80143template <typename OpTy>
81144static mlir::LogicalResult convertOpToCall (OpTy op,
82145 mlir::PatternRewriter &rewriter,
@@ -363,18 +426,14 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
363426 }
364427};
365428
366- static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
367- mlir::Location loc, mlir::Type toTy,
368- mlir::Value val) {
369- if (val.getType () != toTy)
370- return rewriter.create <fir::ConvertOp>(loc, toTy, val);
371- return val;
372- }
373-
374429struct CufDataTransferOpConversion
375430 : public mlir::OpRewritePattern<cuf::DataTransferOp> {
376431 using OpRewritePattern::OpRewritePattern;
377432
433+ CufDataTransferOpConversion (mlir::MLIRContext *context,
434+ const mlir::SymbolTable &symtab)
435+ : OpRewritePattern(context), symtab{symtab} {}
436+
378437 mlir::LogicalResult
379438 matchAndRewrite (cuf::DataTransferOp op,
380439 mlir::PatternRewriter &rewriter) const override {
@@ -445,9 +504,11 @@ struct CufDataTransferOpConversion
445504 mlir::Value sourceLine =
446505 fir::factory::locationToLineNo (builder, loc, fTy .getInput (5 ));
447506
448- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
449- builder, loc, fTy , op.getDst (), op.getSrc (), bytes, modeValue,
450- sourceFile, sourceLine)};
507+ mlir::Value dst = getDeviceAddress (rewriter, op.getDstMutable (), symtab);
508+ mlir::Value src = getDeviceAddress (rewriter, op.getSrcMutable (), symtab);
509+ llvm::SmallVector<mlir::Value> args{
510+ fir::runtime::createArguments (builder, loc, fTy , dst, src, bytes,
511+ modeValue, sourceFile, sourceLine)};
451512 builder.create <fir::CallOp>(loc, func, args);
452513 rewriter.eraseOp (op);
453514 return mlir::success ();
@@ -552,6 +613,9 @@ struct CufDataTransferOpConversion
552613 }
553614 return mlir::success ();
554615 }
616+
617+ private:
618+ const mlir::SymbolTable &symtab;
555619};
556620
557621class CufOpConversion : public fir ::impl::CufOpConversionBase<CufOpConversion> {
@@ -565,13 +629,15 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
565629 mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
566630 if (!module )
567631 return signalPassFailure ();
632+ mlir::SymbolTable symtab (module );
568633
569634 std::optional<mlir::DataLayout> dl =
570635 fir::support::getOrSetDataLayout (module , /* allowDefaultLayout=*/ false );
571636 fir::LLVMTypeConverter typeConverter (module , /* applyTBAA=*/ false ,
572637 /* forceUnifiedTBAATree=*/ false , *dl);
573638 target.addLegalDialect <fir::FIROpsDialect, mlir::arith::ArithDialect>();
574- cuf::populateCUFToFIRConversionPatterns (typeConverter, *dl, patterns);
639+ cuf::populateCUFToFIRConversionPatterns (typeConverter, *dl, symtab,
640+ patterns);
575641 if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
576642 std::move (patterns)))) {
577643 mlir::emitError (mlir::UnknownLoc::get (ctx),
@@ -584,9 +650,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
584650
585651void cuf::populateCUFToFIRConversionPatterns (
586652 const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
587- mlir::RewritePatternSet &patterns) {
653+ const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
588654 patterns.insert <CufAllocOpConversion>(patterns.getContext (), &dl, &converter);
589655 patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion,
590- CufFreeOpConversion, CufDataTransferOpConversion>(
591- patterns.getContext ());
656+ CufFreeOpConversion>(patterns. getContext ());
657+ patterns.insert <CufDataTransferOpConversion>(patterns. getContext (), symtab );
592658}
0 commit comments