diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 8e9de3d328152..7ecb3b1a7bf27 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -507,8 +507,11 @@ struct CUFDataTransferOpConversion using OpRewritePattern::OpRewritePattern; CUFDataTransferOpConversion(mlir::MLIRContext *context, - const mlir::SymbolTable &symtab) - : OpRewritePattern(context), symtab{symtab} {} + const mlir::SymbolTable &symtab, + mlir::DataLayout *dl, + const fir::LLVMTypeConverter *typeConverter) + : OpRewritePattern(context), symtab{symtab}, dl{dl}, + typeConverter{typeConverter} {} mlir::LogicalResult matchAndRewrite(cuf::DataTransferOp op, @@ -576,7 +579,13 @@ struct CUFDataTransferOpConversion nbElement = builder.createIntegerConstant( loc, i64Ty, seqTy.getConstantArraySize()); } - int width = computeWidth(loc, dstTy, kindMap); + unsigned width = 0; + if (fir::isa_derived(dstTy)) { + mlir::Type structTy = typeConverter->convertType(dstTy); + width = dl->getTypeSizeInBits(structTy) / 8; + } else { + width = computeWidth(loc, dstTy, kindMap); + } mlir::Value widthValue = rewriter.create( loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width)); mlir::Value bytes = @@ -647,6 +656,8 @@ struct CUFDataTransferOpConversion private: const mlir::SymbolTable &symtab; + mlir::DataLayout *dl; + const fir::LLVMTypeConverter *typeConverter; }; struct CUFLaunchOpConversion @@ -749,6 +760,7 @@ void cuf::populateCUFToFIRConversionPatterns( patterns.insert(patterns.getContext(), &dl, &converter); patterns.insert(patterns.getContext()); - patterns.insert( - patterns.getContext(), symtab); + patterns.insert(patterns.getContext(), symtab, + &dl, &converter); + patterns.insert(patterns.getContext(), symtab); } diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index 8497aee2e2cf9..1a31c4c6d17a4 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -295,4 +295,17 @@ func.func @_QPscalar_to_array() { // CHECK-LABEL: func.func @_QPscalar_to_array() // CHECK: _FortranACUFDataTransferDescDescNoRealloc +func.func @_QPtest_type() { + %0 = cuf.alloc !fir.type<_QMbarTcmplx{id:i32,c:complex}> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QFtest_typeEa"} -> !fir.ref}>> + %1 = fir.declare %0 {data_attr = #cuf.cuda, uniq_name = "_QFtest_typeEa"} : (!fir.ref}>>) -> !fir.ref}>> + %2 = fir.alloca !fir.type<_QMbarTcmplx{id:i32,c:complex}> {bindc_name = "b", uniq_name = "_QFtest_typeEb"} + %3 = fir.declare %2 {uniq_name = "_QFtest_typeEb"} : (!fir.ref}>>) -> !fir.ref}>> + cuf.data_transfer %3 to %1 {transfer_kind = #cuf.cuda_transfer} : !fir.ref}>>, !fir.ref}>> + cuf.free %1 : !fir.ref}>> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPtest_type() +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%{{.*}}, %{{.*}}, %c12{{.*}}, %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none + } // end of module