@@ -448,6 +448,53 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
448448 return mlir::Value{};
449449}
450450
451+ static mlir::Value emboxSrc (mlir::PatternRewriter &rewriter,
452+ cuf::DataTransferOp op,
453+ const mlir::SymbolTable &symtab) {
454+ auto mod = op->getParentOfType <mlir::ModuleOp>();
455+ mlir::Location loc = op.getLoc ();
456+ fir::FirOpBuilder builder (rewriter, mod);
457+ mlir::Value addr;
458+ mlir::Type srcTy = fir::unwrapRefType (op.getSrc ().getType ());
459+ if (fir::isa_trivial (srcTy) &&
460+ mlir::matchPattern (op.getSrc ().getDefiningOp (), mlir::m_Constant ())) {
461+ // Put constant in memory if it is not.
462+ mlir::Value alloc = builder.createTemporary (loc, srcTy);
463+ builder.create <fir::StoreOp>(loc, op.getSrc (), alloc);
464+ addr = alloc;
465+ } else {
466+ addr = getDeviceAddress (rewriter, op.getSrcMutable (), symtab);
467+ }
468+ llvm::SmallVector<mlir::Value> lenParams;
469+ mlir::Type boxTy = fir::BoxType::get (srcTy);
470+ mlir::Value box =
471+ builder.createBox (loc, boxTy, addr, getShapeFromDecl (op.getSrc ()),
472+ /* slice=*/ nullptr , lenParams,
473+ /* tdesc=*/ nullptr );
474+ mlir::Value src = builder.createTemporary (loc, box.getType ());
475+ builder.create <fir::StoreOp>(loc, box, src);
476+ return src;
477+ }
478+
479+ static mlir::Value emboxDst (mlir::PatternRewriter &rewriter,
480+ cuf::DataTransferOp op,
481+ const mlir::SymbolTable &symtab) {
482+ auto mod = op->getParentOfType <mlir::ModuleOp>();
483+ mlir::Location loc = op.getLoc ();
484+ fir::FirOpBuilder builder (rewriter, mod);
485+ mlir::Type dstTy = fir::unwrapRefType (op.getDst ().getType ());
486+ mlir::Value dstAddr = getDeviceAddress (rewriter, op.getDstMutable (), symtab);
487+ mlir::Type dstBoxTy = fir::BoxType::get (dstTy);
488+ llvm::SmallVector<mlir::Value> lenParams;
489+ mlir::Value dstBox =
490+ builder.createBox (loc, dstBoxTy, dstAddr, getShapeFromDecl (op.getDst ()),
491+ /* slice=*/ nullptr , lenParams,
492+ /* tdesc=*/ nullptr );
493+ mlir::Value dst = builder.createTemporary (loc, dstBox.getType ());
494+ builder.create <fir::StoreOp>(loc, dstBox, dst);
495+ return dst;
496+ }
497+
451498struct CUFDataTransferOpConversion
452499 : public mlir::OpRewritePattern<cuf::DataTransferOp> {
453500 using OpRewritePattern::OpRewritePattern;
@@ -486,10 +533,22 @@ struct CUFDataTransferOpConversion
486533 !mlir::isa<fir::BaseBoxType>(dstTy)) {
487534
488535 if (fir::isa_trivial (srcTy) && !fir::isa_trivial (dstTy)) {
489- // TODO: scalar to array data transfer.
490- mlir::emitError (loc,
491- " not yet implemented: scalar to array data transfer\n " );
492- return mlir::failure ();
536+ // Initialization of an array from a scalar value should be implemented
537+ // via a kernel launch. Use the flan runtime via the Assign function
538+ // until we have more infrastructure.
539+ mlir::Value src = emboxSrc (rewriter, op, symtab);
540+ mlir::Value dst = emboxDst (rewriter, op, symtab);
541+ mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey (
542+ CUFDataTransferDescDescNoRealloc)>(loc, builder);
543+ auto fTy = func.getFunctionType ();
544+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
545+ mlir::Value sourceLine =
546+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (4 ));
547+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
548+ builder, loc, fTy , dst, src, modeValue, sourceFile, sourceLine)};
549+ builder.create <fir::CallOp>(loc, func, args);
550+ rewriter.eraseOp (op);
551+ return mlir::success ();
493552 }
494553
495554 mlir::Type i64Ty = builder.getI64Type ();
@@ -548,29 +607,8 @@ struct CUFDataTransferOpConversion
548607 mlir::Value dst = op.getDst ();
549608 mlir::Value src = op.getSrc ();
550609
551- if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
552- // If src is not a descriptor, create one.
553- mlir::Value addr;
554- if (fir::isa_trivial (srcTy) &&
555- mlir::matchPattern (op.getSrc ().getDefiningOp (),
556- mlir::m_Constant ())) {
557- // Put constant in memory if it is not.
558- mlir::Value alloc = builder.createTemporary (loc, srcTy);
559- builder.create <fir::StoreOp>(loc, op.getSrc (), alloc);
560- addr = alloc;
561- } else {
562- addr = getDeviceAddress (rewriter, op.getSrcMutable (), symtab);
563- }
564- mlir::Type boxTy = fir::BoxType::get (srcTy);
565- llvm::SmallVector<mlir::Value> lenParams;
566- mlir::Value box =
567- builder.createBox (loc, boxTy, addr, getShapeFromDecl (src),
568- /* slice=*/ nullptr , lenParams,
569- /* tdesc=*/ nullptr );
570- mlir::Value memBox = builder.createTemporary (loc, box.getType ());
571- builder.create <fir::StoreOp>(loc, box, memBox);
572- src = memBox;
573- }
610+ if (!mlir::isa<fir::BaseBoxType>(srcTy))
611+ src = emboxSrc (rewriter, op, symtab);
574612
575613 auto fTy = func.getFunctionType ();
576614 mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
@@ -582,16 +620,7 @@ struct CUFDataTransferOpConversion
582620 rewriter.eraseOp (op);
583621 } else {
584622 // Transfer from a descriptor.
585-
586- mlir::Value addr = getDeviceAddress (rewriter, op.getDstMutable (), symtab);
587- mlir::Type boxTy = fir::BoxType::get (dstTy);
588- llvm::SmallVector<mlir::Value> lenParams;
589- mlir::Value box =
590- builder.createBox (loc, boxTy, addr, getShapeFromDecl (op.getDst ()),
591- /* slice=*/ nullptr , lenParams,
592- /* tdesc=*/ nullptr );
593- mlir::Value memBox = builder.createTemporary (loc, box.getType ());
594- builder.create <fir::StoreOp>(loc, box, memBox);
623+ mlir::Value dst = emboxDst (rewriter, op, symtab);
595624
596625 mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey (
597626 CUFDataTransferDescDescNoRealloc)>(loc, builder);
@@ -601,7 +630,7 @@ struct CUFDataTransferOpConversion
601630 mlir::Value sourceLine =
602631 fir::factory::locationToLineNo (builder, loc, fTy .getInput (4 ));
603632 llvm::SmallVector<mlir::Value> args{
604- fir::runtime::createArguments (builder, loc, fTy , memBox , op.getSrc (),
633+ fir::runtime::createArguments (builder, loc, fTy , dst , op.getSrc (),
605634 modeValue, sourceFile, sourceLine)};
606635 builder.create <fir::CallOp>(loc, func, args);
607636 rewriter.eraseOp (op);
0 commit comments