@@ -507,8 +507,11 @@ struct CUFDataTransferOpConversion
507507 using OpRewritePattern::OpRewritePattern;
508508
509509 CUFDataTransferOpConversion (mlir::MLIRContext *context,
510- const mlir::SymbolTable &symtab)
511- : OpRewritePattern(context), symtab{symtab} {}
510+ const mlir::SymbolTable &symtab,
511+ mlir::DataLayout *dl,
512+ const fir::LLVMTypeConverter *typeConverter)
513+ : OpRewritePattern(context), symtab{symtab}, dl{dl},
514+ typeConverter{typeConverter} {}
512515
513516 mlir::LogicalResult
514517 matchAndRewrite (cuf::DataTransferOp op,
@@ -576,7 +579,13 @@ struct CUFDataTransferOpConversion
576579 nbElement = builder.createIntegerConstant (
577580 loc, i64Ty, seqTy.getConstantArraySize ());
578581 }
579- int width = computeWidth (loc, dstTy, kindMap);
582+ unsigned width = 0 ;
583+ if (fir::isa_derived (dstTy)) {
584+ mlir::Type structTy = typeConverter->convertType (dstTy);
585+ width = dl->getTypeSizeInBits (structTy) / 8 ;
586+ } else {
587+ width = computeWidth (loc, dstTy, kindMap);
588+ }
580589 mlir::Value widthValue = rewriter.create <mlir::arith::ConstantOp>(
581590 loc, i64Ty, rewriter.getIntegerAttr (i64Ty, width));
582591 mlir::Value bytes =
@@ -647,6 +656,8 @@ struct CUFDataTransferOpConversion
647656
648657private:
649658 const mlir::SymbolTable &symtab;
659+ mlir::DataLayout *dl;
660+ const fir::LLVMTypeConverter *typeConverter;
650661};
651662
652663struct CUFLaunchOpConversion
@@ -749,6 +760,7 @@ void cuf::populateCUFToFIRConversionPatterns(
749760 patterns.insert <CUFAllocOpConversion>(patterns.getContext (), &dl, &converter);
750761 patterns.insert <CUFAllocateOpConversion, CUFDeallocateOpConversion,
751762 CUFFreeOpConversion>(patterns.getContext ());
752- patterns.insert <CUFDataTransferOpConversion, CUFLaunchOpConversion>(
753- patterns.getContext (), symtab);
763+ patterns.insert <CUFDataTransferOpConversion>(patterns.getContext (), symtab,
764+ &dl, &converter);
765+ patterns.insert <CUFLaunchOpConversion>(patterns.getContext (), symtab);
754766}
0 commit comments