@@ -366,6 +366,23 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
366366 const fir::LLVMTypeConverter *typeConverter;
367367};
368368
/// Emit a call to the CUFGetDeviceAddress runtime entry point for `inputArg`
/// and return the call's result converted back to `inputArg`'s type.
///
/// \param rewriter  pattern rewriter used to create the convert and call ops.
/// \param mod       module used to construct the fir::FirOpBuilder through
///                  which the runtime function is looked up.
/// \param loc       location attached to the created ops; also the source of
///                  the file-name and line-number arguments of the call.
/// \param inputArg  value passed as the first argument of the runtime call
///                  (presumably a host-side address of a registered device
///                  global, going by the function name -- confirm against
///                  the callers).
/// \returns result 0 of the runtime call, converted to `inputArg.getType()`.
369+ static mlir::Value genGetDeviceAddress (mlir::PatternRewriter &rewriter,
370+ mlir::ModuleOp mod, mlir::Location loc,
371+ mlir::Value inputArg) {
372+ fir::FirOpBuilder builder (rewriter, mod);
// Look up the runtime function through the builder.
// NOTE(review): presumably this also declares it in `mod` when it is not
// there yet -- confirm against getRuntimeFunc.
373+ mlir::func::FuncOp callee =
374+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
375+ auto fTy = callee.getFunctionType ();
// Convert the incoming value to the type of the callee's first parameter.
376+ mlir::Value conv = createConvertOp (rewriter, loc, fTy .getInput (0 ), inputArg);
// Source file name and line number are derived from `loc`; the line number is
// produced with the type of the callee's third parameter.
377+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
378+ mlir::Value sourceLine =
379+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
// Assemble the argument list against the callee's function type.
380+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
381+ builder, loc, fTy , conv, sourceFile, sourceLine)};
382+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
// Convert the call result back to the caller-visible type so the returned
// value has the same type as `inputArg`.
383+ return createConvertOp (rewriter, loc, inputArg.getType (), call->getResult (0 ));
384+ }
385+
369386struct DeclareOpConversion : public mlir ::OpRewritePattern<fir::DeclareOp> {
370387 using OpRewritePattern::OpRewritePattern;
371388
@@ -382,26 +399,10 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
382399 if (cuf::isRegisteredDeviceGlobal (global)) {
383400 rewriter.setInsertionPointAfter (addrOfOp);
384401 auto mod = op->getParentOfType <mlir::ModuleOp>();
385- fir::FirOpBuilder builder (rewriter, mod);
386- mlir::Location loc = op.getLoc ();
387- mlir::func::FuncOp callee =
388- fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(
389- loc, builder);
390- auto fTy = callee.getFunctionType ();
391- mlir::Type toTy = fTy .getInput (0 );
392- mlir::Value inputArg =
393- createConvertOp (rewriter, loc, toTy, addrOfOp.getResult ());
394- mlir::Value sourceFile =
395- fir::factory::locationToFilename (builder, loc);
396- mlir::Value sourceLine =
397- fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
398- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
399- builder, loc, fTy , inputArg, sourceFile, sourceLine)};
400- auto call = rewriter.create <fir::CallOp>(loc, callee, args);
401- mlir::Value cast = createConvertOp (
402- rewriter, loc, op.getMemref ().getType (), call->getResult (0 ));
402+ mlir::Value devAddr = genGetDeviceAddress (rewriter, mod, op.getLoc (),
403+ addrOfOp.getResult ());
403404 rewriter.startOpModification (op);
404- op.getMemrefMutable ().assign (cast );
405+ op.getMemrefMutable ().assign (devAddr );
405406 rewriter.finalizeOpModification (op);
406407 return success ();
407408 }
@@ -771,10 +772,32 @@ struct CUFLaunchOpConversion
771772 loc, clusterDimsAttr.getZ ().getInt ());
772773 }
773774 }
775+ llvm::SmallVector<mlir::Value> args;
776+ auto mod = op->getParentOfType <mlir::ModuleOp>();
777+ for (mlir::Value arg : op.getArgs ()) {
778+ // If the argument is a global descriptor, make sure we pass the device
779+ // copy of this descriptor and not the host one.
780+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType (arg.getType ()))) {
781+ if (auto declareOp =
782+ mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp ())) {
783+ if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
784+ declareOp.getMemref ().getDefiningOp ())) {
785+ if (auto global = symTab.lookup <fir::GlobalOp>(
786+ addrOfOp.getSymbol ().getRootReference ().getValue ())) {
787+ if (cuf::isRegisteredDeviceGlobal (global)) {
788+ arg = genGetDeviceAddress (rewriter, mod, op.getLoc (),
789+ declareOp.getResult ());
790+ }
791+ }
792+ }
793+ }
794+ }
795+ args.push_back (arg);
796+ }
797+
774798 auto gpuLaunchOp = rewriter.create <mlir::gpu::LaunchFuncOp>(
775799 loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
776- mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
777- op.getArgs ());
800+ mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
778801 if (clusterDimX && clusterDimY && clusterDimZ) {
779802 gpuLaunchOp.getClusterSizeXMutable ().assign (clusterDimX);
780803 gpuLaunchOp.getClusterSizeYMutable ().assign (clusterDimY);
0 commit comments