@@ -366,6 +366,23 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
366366  const  fir::LLVMTypeConverter *typeConverter;
367367};
368368
369+ static  mlir::Value genGetDeviceAddress (mlir::PatternRewriter &rewriter,
370+                                        mlir::ModuleOp mod, mlir::Location loc,
371+                                        mlir::Value inputArg) {
372+   fir::FirOpBuilder builder (rewriter, mod);
373+   mlir::func::FuncOp callee =
374+       fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
375+   auto  fTy  = callee.getFunctionType ();
376+   mlir::Value conv = createConvertOp (rewriter, loc, fTy .getInput (0 ), inputArg);
377+   mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
378+   mlir::Value sourceLine =
379+       fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
380+   llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
381+       builder, loc, fTy , conv, sourceFile, sourceLine)};
382+   auto  call = rewriter.create <fir::CallOp>(loc, callee, args);
383+   return  createConvertOp (rewriter, loc, inputArg.getType (), call->getResult (0 ));
384+ }
385+ 
369386struct  DeclareOpConversion  : public  mlir ::OpRewritePattern<fir::DeclareOp> {
370387  using  OpRewritePattern::OpRewritePattern;
371388
@@ -382,26 +399,10 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
382399        if  (cuf::isRegisteredDeviceGlobal (global)) {
383400          rewriter.setInsertionPointAfter (addrOfOp);
384401          auto  mod = op->getParentOfType <mlir::ModuleOp>();
385-           fir::FirOpBuilder builder (rewriter, mod);
386-           mlir::Location loc = op.getLoc ();
387-           mlir::func::FuncOp callee =
388-               fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(
389-                   loc, builder);
390-           auto  fTy  = callee.getFunctionType ();
391-           mlir::Type toTy = fTy .getInput (0 );
392-           mlir::Value inputArg =
393-               createConvertOp (rewriter, loc, toTy, addrOfOp.getResult ());
394-           mlir::Value sourceFile =
395-               fir::factory::locationToFilename (builder, loc);
396-           mlir::Value sourceLine =
397-               fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
398-           llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
399-               builder, loc, fTy , inputArg, sourceFile, sourceLine)};
400-           auto  call = rewriter.create <fir::CallOp>(loc, callee, args);
401-           mlir::Value cast = createConvertOp (
402-               rewriter, loc, op.getMemref ().getType (), call->getResult (0 ));
402+           mlir::Value devAddr = genGetDeviceAddress (rewriter, mod, op.getLoc (),
403+                                                     addrOfOp.getResult ());
403404          rewriter.startOpModification (op);
404-           op.getMemrefMutable ().assign (cast );
405+           op.getMemrefMutable ().assign (devAddr );
405406          rewriter.finalizeOpModification (op);
406407          return  success ();
407408        }
@@ -771,10 +772,32 @@ struct CUFLaunchOpConversion
771772            loc, clusterDimsAttr.getZ ().getInt ());
772773      }
773774    }
775+     llvm::SmallVector<mlir::Value> args;
776+     auto  mod = op->getParentOfType <mlir::ModuleOp>();
777+     for  (mlir::Value arg : op.getArgs ()) {
778+       //  If the argument is a global descriptor, make sure we pass the device
779+       //  copy of this descriptor and not the host one.
780+       if  (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType (arg.getType ()))) {
781+         if  (auto  declareOp =
782+                 mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp ())) {
783+           if  (auto  addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
784+                   declareOp.getMemref ().getDefiningOp ())) {
785+             if  (auto  global = symTab.lookup <fir::GlobalOp>(
786+                     addrOfOp.getSymbol ().getRootReference ().getValue ())) {
787+               if  (cuf::isRegisteredDeviceGlobal (global)) {
788+                 arg = genGetDeviceAddress (rewriter, mod, op.getLoc (),
789+                                           declareOp.getResult ());
790+               }
791+             }
792+           }
793+         }
794+       }
795+       args.push_back (arg);
796+     }
797+ 
774798    auto  gpuLaunchOp = rewriter.create <mlir::gpu::LaunchFuncOp>(
775799        loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
776-         mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
777-         op.getArgs ());
800+         mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
778801    if  (clusterDimX && clusterDimY && clusterDimZ) {
779802      gpuLaunchOp.getClusterSizeXMutable ().assign (clusterDimX);
780803      gpuLaunchOp.getClusterSizeYMutable ().assign (clusterDimY);
0 commit comments