clementval
Contributor

When the source value of a `cuf.data_transfer` is the result of a `fir.load` operation, use the memref operand of that `fir.load` instead. Otherwise, the conversion fails with an invalid conversion from `fir.box` to `fir.ref`.
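As a rough illustration of the decision the patch adds, here is a toy model in plain C++. The `TypeKind`, `Value`, and `Op` types and the `getTransferOperand` helper are hypothetical stand-ins, not the MLIR or FIR API. The sketch only mirrors the logic: a box-typed value produced by `fir.load` is replaced by the load's memref (the `!fir.ref<!fir.box<...>>` it was loaded from), and any other value passes through unchanged.

```cpp
#include <cassert>
#include <string>

// Hypothetical, heavily simplified stand-ins for MLIR types, values and ops.
enum class TypeKind { Box, RefToBox, Other };

struct Op;

struct Value {
  TypeKind type;
  const Op *definingOp; // nullptr for block arguments (no defining op)
};

struct Op {
  std::string name;    // operation name, e.g. "fir.load"
  const Value *memref; // address operand of a load; nullptr otherwise
};

// Toy version of the patched lambda: a box value that was produced by a
// load is replaced by the address it was loaded from; every other value is
// returned unchanged.
const Value *getTransferOperand(const Value &val) {
  if (val.type == TypeKind::Box && val.definingOp != nullptr &&
      val.definingOp->name == "fir.load") {
    assert(val.definingOp->memref != nullptr &&
           "a load always has an address operand");
    return val.definingOp->memref;
  }
  return &val;
}
```

In the new test case, `%11 = fir.load %5` is the source of the transfer, so in this model the helper would hand back `%5` rather than `%11`. The sketch also checks for a missing defining op, since block arguments have none.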

@clementval clementval requested a review from wangzpgi October 13, 2025 20:32
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 13, 2025
@llvmbot
Member

llvmbot commented Oct 13, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

When the source value of a `cuf.data_transfer` is the result of a `fir.load` operation, use the memref operand of that `fir.load` instead. Otherwise, the conversion fails with an invalid conversion from `fir.box` to `fir.ref`.


Full diff: https://github.com/llvm/llvm-project/pull/163262.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+3)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+34)
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index e5c5ba9082426..759e3a65dd24f 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -741,6 +741,9 @@ struct CUFDataTransferOpConversion
         fir::StoreOp::create(builder, loc, val, box);
         return box;
       }
+      if (mlir::isa<fir::BaseBoxType>(val.getType()))
+        if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp()))
+          return loadOp.getMemref();
       return val;
     };
 
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 5d3215dd07fce..b247fce44df3d 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -691,5 +691,39 @@ func.func @_QPtesti4(%arg0: !fir.ref<i32> {fir.bindc_name = "n1"}, %arg1: !fir.r
 // CHECK-LABEL: func.func @_QPtesti4
 // CHECK: fir.call @_FortranACUFDataTransferCstDesc
 
+// -----
+
+func.func @_QQmain() attributes {fir.bindc_name = "T"} {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c80 = arith.constant 80 : index
+  %c0 = arith.constant 0 : index
+  %0 = fir.dummy_scope : !fir.dscope
+  %1 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?x?xf16>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>>
+  %2 = fir.zero_bits !fir.heap<!fir.array<?x?x?xf16>>
+  %3 = fir.shape %c0, %c0, %c0 : (index, index, index) -> !fir.shape<3>
+  %4 = fir.embox %2(%3) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?x?x?xf16>>, !fir.shape<3>) -> !fir.box<!fir.heap<!fir.array<?x?x?xf16>>>
+  fir.store %4 to %1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>>
+  %5 = fir.declare %1 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>>
+  %6 = fir.address_of(@_QFEha) : !fir.ref<!fir.array<80x80x80xf32>>
+  %7 = fir.shape %c80, %c80, %c80 : (index, index, index) -> !fir.shape<3>
+  %8 = fir.declare %6(%7) {uniq_name = "_QFEha"} : (!fir.ref<!fir.array<80x80x80xf32>>, !fir.shape<3>) -> !fir.ref<!fir.array<80x80x80xf32>>
+  %9 = fir.address_of(@_QFECn) : !fir.ref<i32>
+  %10 = fir.declare %9 {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QFECn"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  %11 = fir.load %5 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>>
+  %12:3 = fir.box_dims %11, %c0 : (!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>, index) -> (index, index, index)
+  %13:3 = fir.box_dims %11, %c1 : (!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>, index) -> (index, index, index)
+  %14:3 = fir.box_dims %11, %c2 : (!fir.box<!fir.heap<!fir.array<?x?x?xf16>>>, index) -> (index, index, index)
+  %15 = fir.shape %12#1, %13#1, %14#1 : (index, index, index) -> !fir.shape<3>
+  %16 = fir.allocmem !fir.array<?x?x?xf16>, %12#1, %13#1, %14#1 {bindc_name = ".tmp", uniq_name = ""}
+  %17 = fir.declare %16(%15) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?x?x?xf16>>, !fir.shape<3>) -> !fir.heap<!fir.array<?x?x?xf16>>
+  %18 = fir.embox %17(%15) : (!fir.heap<!fir.array<?x?x?xf16>>, !fir.shape<3>) -> !fir.box<!fir.array<?x?x?xf16>>
+  cuf.data_transfer %11 to %18 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.box<!fir.heap<!fir.array<?x?x?xf16>>>, !fir.box<!fir.array<?x?x?xf16>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QQmain() 
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc
+
 } // end of module
 

@clementval clementval merged commit b5d75b2 into llvm:main Oct 13, 2025
13 checks passed
akadutta pushed a commit to akadutta/llvm-project that referenced this pull request Oct 14, 2025
When the source value of a `cuf.data_transfer` is the result of a
`fir.load` operation, use the memref operand of that `fir.load` instead.
Otherwise, the conversion fails with an invalid conversion from `fir.box`
to `fir.ref`.

Labels

flang:fir-hlfir flang Flang issues not falling into any other category