Skip to content

Conversation

@clementval
Copy link
Contributor

Similarly to descriptor for device data, put derived type holding device descriptor in managed memory.

@clementval clementval requested a review from wangzpgi July 2, 2025 22:44
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:semantics labels Jul 2, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Similarly to descriptor for device data, put derived type holding device descriptor in managed memory.


Full diff: https://github.com/llvm/llvm-project/pull/146797.diff

3 Files Affected:

  • (modified) flang/lib/Lower/ConvertVariable.cpp (+25-2)
  • (modified) flang/lib/Semantics/tools.cpp (+13-2)
  • (modified) flang/test/Lower/CUDA/cuda-derived.cuf (+12-2)
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index 7ab3c43016bd9..44f534e7d569a 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -702,6 +702,29 @@ static void instantiateGlobal(Fortran::lower::AbstractConverter &converter,
   mapSymbolAttributes(converter, var, symMap, stmtCtx, cast);
 }
 
+bool needCUDAAlloc(const Fortran::semantics::Symbol &sym) {
+  if (Fortran::semantics::IsDummy(sym))
+    return false;
+  if (const auto *details{
+          sym.GetUltimate()
+              .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
+    if (details->cudaDataAttr() &&
+        (*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Device ||
+         *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Managed ||
+         *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Unified ||
+         *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Shared ||
+         *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Pinned))
+      return true;
+    const Fortran::semantics::DeclTypeSpec *type{details->type()};
+    const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
+                                                            : nullptr};
+    if (derived)
+      if (FindCUDADeviceAllocatableUltimateComponent(*derived))
+        return true;
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------===//
 // Local variables instantiation (not for alias)
 //===----------------------------------------------------------------===//
@@ -732,7 +755,7 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
   if (ultimateSymbol.test(Fortran::semantics::Symbol::Flag::CrayPointee))
     return builder.create<fir::ZeroOp>(loc, fir::ReferenceType::get(ty));
 
-  if (Fortran::semantics::NeedCUDAAlloc(ultimateSymbol)) {
+  if (needCUDAAlloc(ultimateSymbol)) {
     cuf::DataAttributeAttr dataAttr =
         Fortran::lower::translateSymbolCUFDataAttribute(builder.getContext(),
                                                         ultimateSymbol);
@@ -1087,7 +1110,7 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
     Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
                                                symMap);
   auto *builder = &converter.getFirOpBuilder();
-  if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol()) &&
+  if (needCUDAAlloc(var.getSymbol()) &&
       !cuf::isCUDADeviceContext(builder->getRegion())) {
     cuf::DataAttributeAttr dataAttr =
         Fortran::lower::translateSymbolCUFDataAttribute(builder->getContext(),
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
index 498bbc18709ab..ea37316e78273 100644
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -1095,9 +1095,20 @@ bool IsDeviceAllocatable(const Symbol &symbol) {
 }
 
 std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *symbol) {
-  const auto *object{
+  const auto *details{
       symbol ? symbol->detailsIf<ObjectEntityDetails>() : nullptr};
-  return object ? object->cudaDataAttr() : std::nullopt;
+  if (details) {
+    const Fortran::semantics::DeclTypeSpec *type{details->type()};
+    const Fortran::semantics::DerivedTypeSpec *derived{
+        type ? type->AsDerived() : nullptr};
+    if (derived) {
+      if (FindCUDADeviceAllocatableUltimateComponent(*derived)) {
+        return common::CUDADataAttr::Managed;
+      }
+    }
+    return details->cudaDataAttr();
+  }
+  return std::nullopt;
 }
 
 bool IsAccessible(const Symbol &original, const Scope &scope) {
diff --git a/flang/test/Lower/CUDA/cuda-derived.cuf b/flang/test/Lower/CUDA/cuda-derived.cuf
index d280ac722d08f..96250d88d81c4 100644
--- a/flang/test/Lower/CUDA/cuda-derived.cuf
+++ b/flang/test/Lower/CUDA/cuda-derived.cuf
@@ -7,6 +7,16 @@ module m1
 
   type t1; real, device, allocatable :: a(:); end type
   type t2; type(t1) :: b; end type
+contains
+  subroutine sub1()
+    type(ty_device) :: a
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Psub1()
+! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
+! CHECK: cuf.free %[[DECL]]#0 : !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>> {data_attr = #cuf.cuda<managed>}
+
 end module
 
 program main
@@ -16,5 +26,5 @@ program main
 end
 
 ! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"}
-! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", uniq_name = "_QFEa"}
-! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", uniq_name = "_QFEb"}
+! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEa"}
+! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEb"}

@clementval clementval merged commit 925588c into llvm:main Jul 2, 2025
11 of 13 checks passed
@clementval clementval deleted the cuf_derived_global2 branch July 2, 2025 23:02
clementval added a commit that referenced this pull request Jul 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang:semantics flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants