@@ -81,6 +81,15 @@ static bool hasDoubleDescriptors(OpTy op) {
8181 return false ;
8282}
8383
84+ bool isDeviceGlobal (fir::GlobalOp op) {
85+ auto attr = op.getDataAttr ();
86+ if (attr && (*attr == cuf::DataAttribute::Device ||
87+ *attr == cuf::DataAttribute::Managed ||
88+ *attr == cuf::DataAttribute::Constant))
89+ return true ;
90+ return false ;
91+ }
92+
8493static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
8594 mlir::Location loc, mlir::Type toTy,
8695 mlir::Value val) {
@@ -89,62 +98,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
8998 return val;
9099}
91100
92- mlir::Value getDeviceAddress (mlir::PatternRewriter &rewriter,
93- mlir::OpOperand &operand,
94- const mlir::SymbolTable &symtab) {
95- mlir::Value v = operand.get ();
96- auto declareOp = v.getDefiningOp <fir::DeclareOp>();
97- if (!declareOp)
98- return v;
99-
100- auto addrOfOp = declareOp.getMemref ().getDefiningOp <fir::AddrOfOp>();
101- if (!addrOfOp)
102- return v;
103-
104- auto globalOp = symtab.lookup <fir::GlobalOp>(
105- addrOfOp.getSymbol ().getRootReference ().getValue ());
106-
107- if (!globalOp)
108- return v;
109-
110- bool isDevGlobal{false };
111- auto attr = globalOp.getDataAttrAttr ();
112- if (attr) {
113- switch (attr.getValue ()) {
114- case cuf::DataAttribute::Device:
115- case cuf::DataAttribute::Managed:
116- case cuf::DataAttribute::Constant:
117- isDevGlobal = true ;
118- break ;
119- default :
120- break ;
121- }
122- }
123- if (!isDevGlobal)
124- return v;
125- mlir::OpBuilder::InsertionGuard guard (rewriter);
126- rewriter.setInsertionPoint (operand.getOwner ());
127- auto loc = declareOp.getLoc ();
128- auto mod = declareOp->getParentOfType <mlir::ModuleOp>();
129- fir::FirOpBuilder builder (rewriter, mod);
130-
131- mlir::func::FuncOp callee =
132- fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
133- auto fTy = callee.getFunctionType ();
134- auto toTy = fTy .getInput (0 );
135- mlir::Value inputArg =
136- createConvertOp (rewriter, loc, toTy, declareOp.getResult ());
137- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
138- mlir::Value sourceLine =
139- fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
140- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
141- builder, loc, fTy , inputArg, sourceFile, sourceLine)};
142- auto call = rewriter.create <fir::CallOp>(loc, callee, args);
143- mlir::Value cast = createConvertOp (
144- rewriter, loc, declareOp.getMemref ().getType (), call->getResult (0 ));
145- return cast;
146- }
147-
148101template <typename OpTy>
149102static mlir::LogicalResult convertOpToCall (OpTy op,
150103 mlir::PatternRewriter &rewriter,
@@ -422,6 +375,54 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
422375 const fir::LLVMTypeConverter *typeConverter;
423376};
424377
378+ struct DeclareOpConversion : public mlir ::OpRewritePattern<fir::DeclareOp> {
379+ using OpRewritePattern::OpRewritePattern;
380+
381+ DeclareOpConversion (mlir::MLIRContext *context,
382+ const mlir::SymbolTable &symtab)
383+ : OpRewritePattern(context), symTab{symtab} {}
384+
385+ mlir::LogicalResult
386+ matchAndRewrite (fir::DeclareOp op,
387+ mlir::PatternRewriter &rewriter) const override {
388+ if (auto addrOfOp = op.getMemref ().getDefiningOp <fir::AddrOfOp>()) {
389+ if (auto global = symTab.lookup <fir::GlobalOp>(
390+ addrOfOp.getSymbol ().getRootReference ().getValue ())) {
391+ if (isDeviceGlobal (global)) {
392+ rewriter.setInsertionPointAfter (addrOfOp);
393+ auto mod = op->getParentOfType <mlir::ModuleOp>();
394+ fir::FirOpBuilder builder (rewriter, mod);
395+ mlir::Location loc = op.getLoc ();
396+ mlir::func::FuncOp callee =
397+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(
398+ loc, builder);
399+ auto fTy = callee.getFunctionType ();
400+ mlir::Type toTy = fTy .getInput (0 );
401+ mlir::Value inputArg =
402+ createConvertOp (rewriter, loc, toTy, addrOfOp.getResult ());
403+ mlir::Value sourceFile =
404+ fir::factory::locationToFilename (builder, loc);
405+ mlir::Value sourceLine =
406+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
407+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
408+ builder, loc, fTy , inputArg, sourceFile, sourceLine)};
409+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
410+ mlir::Value cast = createConvertOp (
411+ rewriter, loc, op.getMemref ().getType (), call->getResult (0 ));
412+ rewriter.startOpModification (op);
413+ op.getMemrefMutable ().assign (cast);
414+ rewriter.finalizeOpModification (op);
415+ return success ();
416+ }
417+ }
418+ }
419+ return failure ();
420+ }
421+
422+ private:
423+ const mlir::SymbolTable &symTab;
424+ };
425+
425426struct CUFFreeOpConversion : public mlir ::OpRewritePattern<cuf::FreeOp> {
426427 using OpRewritePattern::OpRewritePattern;
427428
@@ -511,7 +512,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
511512 builder.create <fir::StoreOp>(loc, src, alloc);
512513 addr = alloc;
513514 } else {
514- addr = getDeviceAddress (rewriter, op.getSrcMutable (), symtab );
515+ addr = op.getSrc ( );
515516 }
516517 llvm::SmallVector<mlir::Value> lenParams;
517518 mlir::Type boxTy = fir::BoxType::get (srcTy);
@@ -531,7 +532,7 @@ static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
531532 mlir::Location loc = op.getLoc ();
532533 fir::FirOpBuilder builder (rewriter, mod);
533534 mlir::Type dstTy = fir::unwrapRefType (op.getDst ().getType ());
534- mlir::Value dstAddr = getDeviceAddress (rewriter, op.getDstMutable (), symtab );
535+ mlir::Value dstAddr = op.getDst ( );
535536 mlir::Type dstBoxTy = fir::BoxType::get (dstTy);
536537 llvm::SmallVector<mlir::Value> lenParams;
537538 mlir::Value dstBox =
@@ -652,8 +653,8 @@ struct CUFDataTransferOpConversion
652653 mlir::Value sourceLine =
653654 fir::factory::locationToLineNo (builder, loc, fTy .getInput (5 ));
654655
655- mlir::Value dst = getDeviceAddress (rewriter, op.getDstMutable (), symtab );
656- mlir::Value src = getDeviceAddress (rewriter, op.getSrcMutable (), symtab );
656+ mlir::Value dst = op.getDst ( );
657+ mlir::Value src = op.getSrc ( );
657658 // Materialize the src if constant.
658659 if (matchPattern (src.getDefiningOp (), mlir::m_Constant ())) {
659660 mlir::Value temp = builder.createTemporary (loc, srcTy);
@@ -823,6 +824,30 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
823824 " error in CUF op conversion\n " );
824825 signalPassFailure ();
825826 }
827+
828+ target.addDynamicallyLegalOp <fir::DeclareOp>([&](fir::DeclareOp op) {
829+ if (inDeviceContext (op))
830+ return true ;
831+ if (auto addrOfOp = op.getMemref ().getDefiningOp <fir::AddrOfOp>()) {
832+ if (auto global = symtab.lookup <fir::GlobalOp>(
833+ addrOfOp.getSymbol ().getRootReference ().getValue ())) {
834+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType (global.getType ())))
835+ return true ;
836+ if (isDeviceGlobal (global))
837+ return false ;
838+ }
839+ }
840+ return true ;
841+ });
842+
843+ patterns.clear ();
844+ cuf::populateFIRCUFConversionPatterns (symtab, patterns);
845+ if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
846+ std::move (patterns)))) {
847+ mlir::emitError (mlir::UnknownLoc::get (ctx),
848+ " error in CUF op conversion\n " );
849+ signalPassFailure ();
850+ }
826851 }
827852};
828853} // namespace
@@ -837,3 +862,8 @@ void cuf::populateCUFToFIRConversionPatterns(
837862 &dl, &converter);
838863 patterns.insert <CUFLaunchOpConversion>(patterns.getContext (), symtab);
839864}
865+
866+ void cuf::populateFIRCUFConversionPatterns (const mlir::SymbolTable &symtab,
867+ mlir::RewritePatternSet &patterns) {
868+ patterns.insert <DeclareOpConversion>(patterns.getContext (), symtab);
869+ }
0 commit comments