From 717429e5e5dfd8e22787822f23e192258960f050 Mon Sep 17 00:00:00 2001 From: David Truby Date: Tue, 29 Oct 2024 12:55:22 +0000 Subject: [PATCH 1/3] [flang] AArch64 support for BIND(C) derived return types This patch adds support for BIND(C) derived types as return values matching the AArch64 Procedure Call Standard for C. Support for BIND(C) derived types as value parameters will be in a separate patch. --- flang/lib/Optimizer/CodeGen/Target.cpp | 42 ++++++ flang/test/Fir/struct-return-aarch64.fir | 156 +++++++++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 flang/test/Fir/struct-return-aarch64.fir diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp index 6c148dffb0e55..15ffdb74ef51d 100644 --- a/flang/lib/Optimizer/CodeGen/Target.cpp +++ b/flang/lib/Optimizer/CodeGen/Target.cpp @@ -825,6 +825,48 @@ struct TargetAArch64 : public GenericTarget { } return marshal; } + + static bool isHFA(fir::RecordType ty) { + auto types = ty.getTypeList(); + if (types.empty() || types.size() > 4) { + return false; + } + + if (!isa_real(types.front().second)) { + types.front().second.dump(); + return false; + } + + return llvm::all_equal(llvm::make_second_range(types)); + } + + CodeGenSpecifics::Marshalling + structReturnType(mlir::Location loc, fir::RecordType ty) const override { + CodeGenSpecifics::Marshalling marshal; + + if (isHFA(ty)) { + auto newTy = fir::SequenceType::get({ty.getNumFields()}, ty.getType(0)); + marshal.emplace_back(newTy, AT{}); + return marshal; + } + + auto [size, align] = + fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap); + + // return in registers if size <= 16 bytes + if (size <= 16) { + auto dwordSize = (size + 7) / 8; + auto newTy = fir::SequenceType::get( + dwordSize, mlir::IntegerType::get(ty.getContext(), 64)); + marshal.emplace_back(newTy, AT{}); + return marshal; + } + + unsigned short stackAlign = std::max(align, 8u); + 
marshal.emplace_back(fir::ReferenceType::get(ty), + AT{stackAlign, false, true}); + return marshal; + } }; } // namespace diff --git a/flang/test/Fir/struct-return-aarch64.fir b/flang/test/Fir/struct-return-aarch64.fir new file mode 100644 index 0000000000000..96f2f9999b343 --- /dev/null +++ b/flang/test/Fir/struct-return-aarch64.fir @@ -0,0 +1,156 @@ +// Test AArch64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types). +// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s + +!composite = !fir.type +// CHECK-LABEL: func.func private @test_composite() -> !fir.array<2xi64> +func.func private @test_composite() -> !composite +// CHECK-LABEL: func.func @test_call_composite( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) +func.func @test_call_composite(%arg0 : !fir.ref) { + // CHECK: %[[OUT:.*]] = fir.call @test_composite() : () -> !fir.array<2xi64> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xi64> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_composite() : () -> !composite + // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref> + fir.store %out to %arg0 : !fir.ref + // CHECK: return + return +} + +!hfa_f16 = !fir.type +// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.array<2xf16> +func.func private @test_hfa_f16() -> !hfa_f16 +// CHECK-LABEL: func.func @test_call_hfa_f16( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { +func.func @test_call_hfa_f16(%arg0 : !fir.ref) { + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.array<2xf16> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xf16> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = 
fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_hfa_f16() : () -> !hfa_f16 + // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref> + fir.store %out to %arg0 : !fir.ref + return +} + +!hfa_f32 = !fir.type +// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.array<4xf32> +func.func private @test_hfa_f32() -> !hfa_f32 +// CHECK-LABEL: func.func @test_call_hfa_f32( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { +func.func @test_call_hfa_f32(%arg0 : !fir.ref) { + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.array<4xf32> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf32> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_hfa_f32() : () -> !hfa_f32 + // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref> + fir.store %out to %arg0 : !fir.ref + return +} + +!hfa_f64 = !fir.type +// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.array<3xf64> +func.func private @test_hfa_f64() -> !hfa_f64 +// CHECK-LABEL: func.func @test_call_hfa_f64( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) +func.func @test_call_hfa_f64(%arg0 : !fir.ref) { + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.array<3xf64> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<3xf64> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_hfa_f64() : () -> !hfa_f64 + // CHECK: fir.store %[[LD]] to %[[ARG0]] : 
!fir.ref> + fir.store %out to %arg0 : !fir.ref + return +} + +!hfa_f128 = !fir.type +// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.array<4xf128> +func.func private @test_hfa_f128() -> !hfa_f128 +// CHECK-LABEL: func.func @test_call_hfa_f128( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { +func.func @test_call_hfa_f128(%arg0 : !fir.ref) { + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.array<4xf128> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf128> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_hfa_f128() : () -> !hfa_f128 + // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref> + fir.store %out to %arg0 : !fir.ref + return +} + +!hfa_bf16 = !fir.type +// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.array<4xbf16> +func.func private @test_hfa_bf16() -> !hfa_bf16 +// CHECK-LABEL: func.func @test_call_hfa_bf16( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { +func.func @test_call_hfa_bf16(%arg0 : !fir.ref) { + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.array<4xbf16> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xbf16> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_hfa_bf16() : () -> !hfa_bf16 + // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref> + fir.store %out to %arg0 : !fir.ref + return +} + +!too_big = !fir.type +// CHECK-LABEL: func.func private @test_too_big(!fir.ref> +// CHECK-SAME: {llvm.align = 8 : i32, llvm.sret = !fir.type}) +func.func private 
@test_too_big() -> !too_big +// CHECK-LABEL: func.func @test_call_too_big( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { +func.func @test_call_too_big(%arg0 : !fir.ref) { + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARG:.*]] = fir.alloca !fir.type + // CHECK: fir.call @test_too_big(%[[ARG]]) : (!fir.ref>) -> () + // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_too_big() : () -> !too_big + // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref> + fir.store %out to %arg0 : !fir.ref + return +} + + +!too_big_hfa = !fir.type}> +// CHECK-LABEL: func.func private @test_too_big_hfa(!fir.ref}>> +// CHECK-SAME: {llvm.align = 8 : i32, llvm.sret = !fir.type}>}) +func.func private @test_too_big_hfa() -> !too_big_hfa +// CHECK-LABEL: func.func @test_call_too_big_hfa( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref}>>) { +func.func @test_call_too_big_hfa(%arg0 : !fir.ref) { + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARG:.*]] = fir.alloca !fir.type}> + // CHECK: fir.call @test_too_big_hfa(%[[ARG]]) : (!fir.ref}>>) -> () + // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref}>>) -> !fir.ref}>> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref}>> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + %out = fir.call @test_too_big_hfa() : () -> !too_big_hfa + // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref}>> + fir.store %out to %arg0 : !fir.ref + return +} From da590f308a86a6485871c2e821244f9d94ff7ad6 Mon Sep 17 00:00:00 2001 From: David Truby Date: Wed, 30 Oct 2024 12:54:44 +0000 Subject: [PATCH 2/3] Fixes for review --- flang/lib/Optimizer/CodeGen/Target.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp index 15ffdb74ef51d..31a6a6f3aaa13 100644 --- 
a/flang/lib/Optimizer/CodeGen/Target.cpp +++ b/flang/lib/Optimizer/CodeGen/Target.cpp @@ -826,6 +826,8 @@ struct TargetAArch64 : public GenericTarget { return marshal; } + // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An + // HFA is a record type with up to 4 floating-point members of the same type. static bool isHFA(fir::RecordType ty) { auto types = ty.getTypeList(); if (types.empty() || types.size() > 4) { @@ -833,13 +835,14 @@ struct TargetAArch64 : public GenericTarget { } if (!isa_real(types.front().second)) { - types.front().second.dump(); return false; } return llvm::all_equal(llvm::make_second_range(types)); } + // AArch64 procedure call ABI: + // https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing CodeGenSpecifics::Marshalling structReturnType(mlir::Location loc, fir::RecordType ty) const override { CodeGenSpecifics::Marshalling marshal; @@ -855,7 +858,7 @@ struct TargetAArch64 : public GenericTarget { // return in registers if size <= 16 bytes if (size <= 16) { - auto dwordSize = (size + 7) / 8; + std::size_t dwordSize = (size + 7) / 8; auto newTy = fir::SequenceType::get( dwordSize, mlir::IntegerType::get(ty.getContext(), 64)); marshal.emplace_back(newTy, AT{}); From 67e144028c77d0265028df03c26327b5361be1f5 Mon Sep 17 00:00:00 2001 From: David Truby Date: Tue, 12 Nov 2024 18:39:49 +0000 Subject: [PATCH 3/3] Fix isHFA for nested structure types --- flang/lib/Optimizer/CodeGen/Target.cpp | 58 +++++++++-- flang/test/Fir/struct-return-aarch64.fir | 123 ++++++++++++++++++----- 2 files changed, 150 insertions(+), 31 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp index 31a6a6f3aaa13..06b19c0031e43 100644 --- a/flang/lib/Optimizer/CodeGen/Target.cpp +++ b/flang/lib/Optimizer/CodeGen/Target.cpp @@ -826,19 +826,65 @@ struct TargetAArch64 : public GenericTarget { return marshal; } + // Flatten a RecordType::TypeList containing more record types 
or array types + static std::optional<std::vector<mlir::Type>> + flattenTypeList(const RecordType::TypeList &types) { + std::vector<mlir::Type> flatTypes; + // The flat list is usually at least as long as the non-flat list. + flatTypes.reserve(types.size()); + for (const auto &[c, type] : types) { + // Flatten record type + if (auto recTy = mlir::dyn_cast<fir::RecordType>(type)) { + auto subTypeList = flattenTypeList(recTy.getTypeList()); + if (!subTypeList) + return std::nullopt; + llvm::copy(*subTypeList, std::back_inserter(flatTypes)); + continue; + } + + // Flatten array type + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(type)) { + if (seqTy.hasDynamicExtents()) + return std::nullopt; + std::size_t n = seqTy.getConstantArraySize(); + auto eleTy = seqTy.getElementType(); + // Flatten array of record types + if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) { + auto subTypeList = flattenTypeList(recTy.getTypeList()); + if (!subTypeList) + return std::nullopt; + for (std::size_t i = 0; i < n; ++i) + llvm::copy(*subTypeList, std::back_inserter(flatTypes)); + } else { + std::fill_n(std::back_inserter(flatTypes), + seqTy.getConstantArraySize(), eleTy); + } + continue; + } + + // Other types are already flat + flatTypes.push_back(type); + } + return flatTypes; + } + + // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An // HFA is a record type with up to 4 floating-point members of the same type. 
static bool isHFA(fir::RecordType ty) { - auto types = ty.getTypeList(); - if (types.empty() || types.size() > 4) { + RecordType::TypeList types = ty.getTypeList(); + if (types.empty() || types.size() > 4) + return false; + + std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types); + if (!flatTypes || flatTypes->empty() || flatTypes->size() > 4) { return false; } - if (!isa_real(types.front().second)) { + if (!isa_real(flatTypes->front())) { return false; } - return llvm::all_equal(llvm::make_second_range(types)); + return llvm::all_equal(*flatTypes); } // AArch64 procedure call ABI: @@ -848,8 +894,8 @@ struct TargetAArch64 : public GenericTarget { CodeGenSpecifics::Marshalling marshal; if (isHFA(ty)) { - auto newTy = fir::SequenceType::get({ty.getNumFields()}, ty.getType(0)); - marshal.emplace_back(newTy, AT{}); + // Just return the existing record type + marshal.emplace_back(ty, AT{}); return marshal; } diff --git a/flang/test/Fir/struct-return-aarch64.fir b/flang/test/Fir/struct-return-aarch64.fir index 96f2f9999b343..8b75c2cac7b6b 100644 --- a/flang/test/Fir/struct-return-aarch64.fir +++ b/flang/test/Fir/struct-return-aarch64.fir @@ -22,16 +22,16 @@ func.func @test_call_composite(%arg0 : !fir.ref) { } !hfa_f16 = !fir.type -// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.array<2xf16> +// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.type func.func private @test_hfa_f16() -> !hfa_f16 // CHECK-LABEL: func.func @test_call_hfa_f16( // CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { func.func @test_call_hfa_f16(%arg0 : !fir.ref) { - // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.array<2xf16> + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.type // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr - // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xf16> - // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> - // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type + // CHECK: fir.store 
%[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr %out = fir.call @test_hfa_f16() : () -> !hfa_f16 @@ -41,16 +41,16 @@ func.func @test_call_hfa_f16(%arg0 : !fir.ref) { } !hfa_f32 = !fir.type -// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.array<4xf32> +// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.type func.func private @test_hfa_f32() -> !hfa_f32 // CHECK-LABEL: func.func @test_call_hfa_f32( // CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { func.func @test_call_hfa_f32(%arg0 : !fir.ref) { - // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.array<4xf32> + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.type // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr - // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf32> - // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> - // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr %out = fir.call @test_hfa_f32() : () -> !hfa_f32 @@ -60,16 +60,16 @@ func.func @test_call_hfa_f32(%arg0 : !fir.ref) { } !hfa_f64 = !fir.type -// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.array<3xf64> +// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.type func.func private @test_hfa_f64() -> !hfa_f64 // CHECK-LABEL: func.func @test_call_hfa_f64( // CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) func.func @test_call_hfa_f64(%arg0 : !fir.ref) { - // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.array<3xf64> + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.type // CHECK: %[[STACK:.*]] = llvm.intr.stacksave 
: !llvm.ptr - // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<3xf64> - // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> - // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr %out = fir.call @test_hfa_f64() : () -> !hfa_f64 @@ -79,16 +79,16 @@ func.func @test_call_hfa_f64(%arg0 : !fir.ref) { } !hfa_f128 = !fir.type -// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.array<4xf128> +// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.type func.func private @test_hfa_f128() -> !hfa_f128 // CHECK-LABEL: func.func @test_call_hfa_f128( // CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { func.func @test_call_hfa_f128(%arg0 : !fir.ref) { - // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.array<4xf128> + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.type // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr - // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf128> - // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> - // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr %out = fir.call @test_hfa_f128() : () -> !hfa_f128 @@ -98,16 +98,16 @@ func.func @test_call_hfa_f128(%arg0 : !fir.ref) { } !hfa_bf16 = !fir.type -// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.array<4xbf16> +// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.type func.func private @test_hfa_bf16() -> !hfa_bf16 // CHECK-LABEL: func.func 
@test_call_hfa_bf16( // CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) { func.func @test_call_hfa_bf16(%arg0 : !fir.ref) { - // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.array<4xbf16> + // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.type // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr - // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xbf16> - // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> - // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref> // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr %out = fir.call @test_hfa_bf16() : () -> !hfa_bf16 @@ -154,3 +154,76 @@ func.func @test_call_too_big_hfa(%arg0 : !fir.ref) { fir.store %out to %arg0 : !fir.ref return } + +!nested_hfa_first = !fir.type +// CHECK-LABEL: func.func private @test_nested_hfa_first() -> !fir.type,c:f16}> +func.func private @test_nested_hfa_first() -> !nested_hfa_first +// CHECK-LABEL: func.func @test_call_nested_hfa_first(%arg0: !fir.ref,c:f16}>>) { +func.func @test_call_nested_hfa_first(%arg0 : !fir.ref) { + %out = fir.call @test_nested_hfa_first() : () -> !nested_hfa_first + // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_first() : () -> !fir.type,c:f16}> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type,c:f16}> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref,c:f16}>> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref,c:f16}>> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref,c:f16}>> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + fir.store %out to %arg0 : !fir.ref + // CHECK: fir.store %[[LD]] to %arg0 : !fir.ref,c:f16}>> + return +} + + +!nested_hfa_middle = !fir.type +// CHECK-LABEL: func.func private 
@test_nested_hfa_middle() -> !fir.type,c:f16}> +func.func private @test_nested_hfa_middle() -> !nested_hfa_middle +// CHECK-LABEL: func.func @test_call_nested_hfa_middle(%arg0: !fir.ref,c:f16}>>) { +func.func @test_call_nested_hfa_middle(%arg0 : !fir.ref) { + %out = fir.call @test_nested_hfa_middle() : () -> !nested_hfa_middle + // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_middle() : () -> !fir.type,c:f16}> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type,c:f16}> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref,c:f16}>> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref,c:f16}>> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref,c:f16}>> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + fir.store %out to %arg0 : !fir.ref + // CHECK: fir.store %[[LD]] to %arg0 : !fir.ref,c:f16}>> + return +} + +!nested_hfa_end = !fir.type +// CHECK-LABEL: func.func private @test_nested_hfa_end() -> !fir.type}> +func.func private @test_nested_hfa_end() -> !nested_hfa_end +// CHECK-LABEL: func.func @test_call_nested_hfa_end(%arg0: !fir.ref}>>) { +func.func @test_call_nested_hfa_end(%arg0 : !fir.ref) { + %out = fir.call @test_nested_hfa_end() : () -> !nested_hfa_end + // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_end() : () -> !fir.type}> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type}> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref}>> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref}>> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref}>> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + fir.store %out to %arg0 : !fir.ref + // CHECK: fir.store %[[LD]] to %arg0 : !fir.ref}>> + return +} + +!nested_hfa_array = !fir.type,b:f32}> +// CHECK-LABEL: func.func private @test_nested_hfa_array() -> !fir.type,b:f32}> +func.func private @test_nested_hfa_array() -> !nested_hfa_array +// CHECK-LABEL: func.func 
@test_call_nested_hfa_array(%arg0: !fir.ref,b:f32}> +func.func @test_call_nested_hfa_array(%arg0 : !fir.ref) { + %out = fir.call @test_nested_hfa_array() : () -> !nested_hfa_array + // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_array() : () -> !fir.type,b:f32}> + // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr + // CHECK: %[[ARR:.*]] = fir.alloca !fir.type,b:f32}> + // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref,b:f32}> + // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref,b:f32}> + // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref,b:f32}> + // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr + fir.store %out to %arg0 : !fir.ref + // CHECK: fir.store %[[LD]] to %arg0 : !fir.ref,b:f32}> + return +}