diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 6c148dffb0e55..06b19c0031e43 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -825,6 +825,97 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     }
     return marshal;
   }
+
+  // Flatten a RecordType::TypeList that may contain nested record or array types
+  static std::optional<std::vector<mlir::Type>>
+  flattenTypeList(const RecordType::TypeList &types) {
+    std::vector<mlir::Type> flatTypes;
+    // The flat list will be at least the same size as the non-flat list.
+    flatTypes.reserve(types.size());
+    for (auto [c, type] : types) {
+      // Flatten record type
+      if (auto recTy = mlir::dyn_cast<fir::RecordType>(type)) {
+        auto subTypeList = flattenTypeList(recTy.getTypeList());
+        if (!subTypeList)
+          return std::nullopt;
+        llvm::copy(*subTypeList, std::back_inserter(flatTypes));
+        continue;
+      }
+
+      // Flatten array type
+      if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(type)) {
+        if (seqTy.hasDynamicExtents())
+          return std::nullopt;
+        std::size_t n = seqTy.getConstantArraySize();
+        auto eleTy = seqTy.getElementType();
+        // Flatten array of record types
+        if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) {
+          auto subTypeList = flattenTypeList(recTy.getTypeList());
+          if (!subTypeList)
+            return std::nullopt;
+          for (std::size_t i = 0; i < n; ++i)
+            llvm::copy(*subTypeList, std::back_inserter(flatTypes));
+        } else {
+          std::fill_n(std::back_inserter(flatTypes),
+                      seqTy.getConstantArraySize(), eleTy);
+        }
+        continue;
+      }
+
+      // Other types are already flat
+      flatTypes.push_back(type);
+    }
+    return flatTypes;
+  }
+
+  // Determine if the type is a Homogeneous Floating-point Aggregate (HFA). An
+  // HFA is a record type with between one and four floating-point members of
+  // the same type.
+  static bool isHFA(fir::RecordType ty) {
+    RecordType::TypeList types = ty.getTypeList();
+    if (types.empty() || types.size() > 4)
+      return false;
+
+    std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
+    if (!flatTypes || flatTypes->size() > 4) {
+      return false;
+    }
+
+    if (!isa_real(flatTypes->front())) {
+      return false;
+    }
+
+    return llvm::all_equal(*flatTypes);
+  }
+
+  // AArch64 procedure call ABI:
+  // https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
+  CodeGenSpecifics::Marshalling
+  structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+    CodeGenSpecifics::Marshalling marshal;
+
+    if (isHFA(ty)) {
+      // Just return the existing record type
+      marshal.emplace_back(ty, AT{});
+      return marshal;
+    }
+
+    auto [size, align] =
+        fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+
+    // Return in registers if size <= 16 bytes
+    if (size <= 16) {
+      std::size_t dwordSize = (size + 7) / 8;
+      auto newTy = fir::SequenceType::get(
+          dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
+      marshal.emplace_back(newTy, AT{});
+      return marshal;
+    }
+
+    unsigned short stackAlign = std::max<unsigned short>(align, 8u);
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{stackAlign, false, true});
+    return marshal;
+  }
 };
 
 } // namespace
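The HFA check above follows the AAPCS64 rule: after recursively flattening nested records and fixed-size arrays, an aggregate is an HFA when it has between one and four floating-point leaves, all of the same type. The standalone C++ sketch below restates that rule over a toy type model for illustration only; it does not use the FIR/MLIR APIs from the patch, and the names (Scalar, Node, isHFAShape) are invented here.

// Toy model of the flattening + HFA classification (illustration only).
#include <cassert>
#include <optional>
#include <vector>

enum class Scalar { F16, BF16, F32, F64, F128, I32 };

// A node is either a scalar leaf or an aggregate of nested nodes.
// (A fixed-size array can be modeled as an aggregate repeating its element.)
struct Node {
  std::optional<Scalar> leaf;  // set for scalar leaves
  std::vector<Node> members;   // non-empty for aggregates
};

// Recursively collect the scalar leaves, mirroring flattenTypeList above.
static void flatten(const Node &n, std::vector<Scalar> &out) {
  if (n.leaf) {
    out.push_back(*n.leaf);
    return;
  }
  for (const Node &m : n.members)
    flatten(m, out);
}

// AAPCS64 HFA rule: 1 to 4 floating-point leaves, all of the same type.
static bool isHFAShape(const Node &agg) {
  std::vector<Scalar> leaves;
  flatten(agg, leaves);
  if (leaves.empty() || leaves.size() > 4)
    return false;
  auto isFloat = [](Scalar s) { return s != Scalar::I32; }; // I32 is the only non-FP leaf in this toy model
  if (!isFloat(leaves.front()))
    return false;
  for (Scalar s : leaves)
    if (s != leaves.front())
      return false;
  return true;
}

int main() {
  Node f32{Scalar::F32, {}}, i32{Scalar::I32, {}};
  Node hfa{{}, {f32, f32, f32}};                // 3 x f32 -> HFA
  Node nested{{}, {Node{{}, {f32, f32}}, f32}}; // nested record -> still an HFA
  Node mixed{{}, {f32, i32}};                   // mixed leaf types -> not an HFA
  Node five{{}, {f32, f32, f32, f32, f32}};     // 5 leaves -> not an HFA
  assert(isHFAShape(hfa) && isHFAShape(nested));
  assert(!isHFAShape(mixed) && !isHFAShape(five));
}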
diff --git a/flang/test/Fir/struct-return-aarch64.fir b/flang/test/Fir/struct-return-aarch64.fir
new file mode 100644
index 0000000000000..8b75c2cac7b6b
--- /dev/null
+++ b/flang/test/Fir/struct-return-aarch64.fir
@@ -0,0 +1,229 @@
+// Test AArch64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
+// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
+
+!composite = !fir.type
+// CHECK-LABEL: func.func private @test_composite() -> !fir.array<2xi64>
+func.func private @test_composite() -> !composite
+// CHECK-LABEL: func.func @test_call_composite(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>)
+func.func @test_call_composite(%arg0 : !fir.ref) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_composite() : () -> !fir.array<2xi64>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xi64>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_composite() : () -> !composite
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref>
+  fir.store %out to %arg0 : !fir.ref
+  // CHECK: return
+  return
+}
+
+!hfa_f16 = !fir.type
+// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.type
+func.func private @test_hfa_f16() -> !hfa_f16
+// CHECK-LABEL: func.func @test_call_hfa_f16(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) {
+func.func @test_call_hfa_f16(%arg0 : !fir.ref) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.type
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f16() : () -> !hfa_f16
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref>
+  fir.store %out to %arg0 : !fir.ref
+  return
+}
+
+!hfa_f32 = !fir.type
+// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.type
+func.func private @test_hfa_f32() -> !hfa_f32
+// CHECK-LABEL: func.func @test_call_hfa_f32(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) {
+func.func @test_call_hfa_f32(%arg0 : !fir.ref) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.type
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f32() : () -> !hfa_f32
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref>
+  fir.store %out to %arg0 : !fir.ref
+  return
+}
+
+!hfa_f64 = !fir.type
+// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.type
+func.func private @test_hfa_f64() -> !hfa_f64
+// CHECK-LABEL: func.func @test_call_hfa_f64(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>)
+func.func @test_call_hfa_f64(%arg0 : !fir.ref) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.type
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f64() : () -> !hfa_f64
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref>
+  fir.store %out to %arg0 : !fir.ref
+  return
+}
+
+!hfa_f128 = !fir.type
+// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.type
+func.func private @test_hfa_f128() -> !hfa_f128
+// CHECK-LABEL: func.func @test_call_hfa_f128(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) {
+func.func @test_call_hfa_f128(%arg0 : !fir.ref) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.type
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f128() : () -> !hfa_f128
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref>
+  fir.store %out to %arg0 : !fir.ref
+  return
+}
+
+!hfa_bf16 = !fir.type
+// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.type
+func.func private @test_hfa_bf16() -> !hfa_bf16
+// CHECK-LABEL: func.func @test_call_hfa_bf16(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) {
+func.func @test_call_hfa_bf16(%arg0 : !fir.ref) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.type
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref>) -> !fir.ref>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_bf16() : () -> !hfa_bf16
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref>
+  fir.store %out to %arg0 : !fir.ref
+  return
+}
+
+!too_big = !fir.type
+// CHECK-LABEL: func.func private @test_too_big(!fir.ref>
+// CHECK-SAME: {llvm.align = 8 : i32, llvm.sret = !fir.type})
+func.func private @test_too_big() -> !too_big
+// CHECK-LABEL: func.func @test_call_too_big(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>) {
+func.func @test_call_too_big(%arg0 : !fir.ref) {
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARG:.*]] = fir.alloca !fir.type
+  // CHECK: fir.call @test_too_big(%[[ARG]]) : (!fir.ref>) -> ()
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref>) -> !fir.ref>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_too_big() : () -> !too_big
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref>
+  fir.store %out to %arg0 : !fir.ref
+  return
+}
+
+
+!too_big_hfa = !fir.type}>
+// CHECK-LABEL: func.func private @test_too_big_hfa(!fir.ref}>>
+// CHECK-SAME: {llvm.align = 8 : i32, llvm.sret = !fir.type}>})
+func.func private @test_too_big_hfa() -> !too_big_hfa
+// CHECK-LABEL: func.func @test_call_too_big_hfa(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref}>>) {
+func.func @test_call_too_big_hfa(%arg0 : !fir.ref) {
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARG:.*]] = fir.alloca !fir.type}>
+  // CHECK: fir.call @test_too_big_hfa(%[[ARG]]) : (!fir.ref}>>) -> ()
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref}>>) -> !fir.ref}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_too_big_hfa() : () -> !too_big_hfa
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref}>>
+  fir.store %out to %arg0 : !fir.ref
+  return
+}
+
+!nested_hfa_first = !fir.type
+// CHECK-LABEL: func.func private @test_nested_hfa_first() -> !fir.type,c:f16}>
+func.func private @test_nested_hfa_first() -> !nested_hfa_first
+// CHECK-LABEL: func.func @test_call_nested_hfa_first(%arg0: !fir.ref,c:f16}>>) {
+func.func @test_call_nested_hfa_first(%arg0 : !fir.ref) {
+  %out = fir.call @test_nested_hfa_first() : () -> !nested_hfa_first
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_first() : () -> !fir.type,c:f16}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type,c:f16}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref,c:f16}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref,c:f16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref,c:f16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref
+  // CHECK fir.store %[[LD]] to %[[ARG0]] : !fir.ref,c:f16}>>
+  return
+}
+
+
+!nested_hfa_middle = !fir.type
+// CHECK-LABEL: func.func private @test_nested_hfa_middle() -> !fir.type,c:f16}>
+func.func private @test_nested_hfa_middle() -> !nested_hfa_middle
+// CHECK-LABEL: func.func @test_call_nested_hfa_middle(%arg0: !fir.ref,c:f16}>>) {
+func.func @test_call_nested_hfa_middle(%arg0 : !fir.ref) {
+  %out = fir.call @test_nested_hfa_middle() : () -> !nested_hfa_middle
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_middle() : () -> !fir.type,c:f16}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type,c:f16}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref,c:f16}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref,c:f16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref,c:f16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref
+  // CHECK fir.store %[[LD]] to %[[ARG0]] : !fir.ref,c:f16}>>
+  return
+}
+
+!nested_hfa_end = !fir.type
+// CHECK-LABEL: func.func private @test_nested_hfa_end() -> !fir.type}>
+func.func private @test_nested_hfa_end() -> !nested_hfa_end
+// CHECK-LABEL: func.func @test_call_nested_hfa_end(%arg0: !fir.ref}>>) {
+func.func @test_call_nested_hfa_end(%arg0 : !fir.ref) {
+  %out = fir.call @test_nested_hfa_end() : () -> !nested_hfa_end
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_end() : () -> !fir.type}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref
+  // CHECK fir.store %[[LD]] to %[[ARG0]] : !fir.ref}>>
+  return
+}
+
+!nested_hfa_array = !fir.type,b:f32}>
+// CHECK-LABEL: func.func private @test_nested_hfa_array() -> !fir.type,b:f32}>
+func.func private @test_nested_hfa_array() -> !nested_hfa_array
+// CHECK-LABEL: func.func @test_call_nested_hfa_array(%arg0: !fir.ref,b:f32}>
+func.func @test_call_nested_hfa_array(%arg0 : !fir.ref) {
+  %out = fir.call @test_nested_hfa_array() : () -> !nested_hfa_array
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_array() : () -> !fir.type,b:f32}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type,b:f32}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref,b:f32}>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref,b:f32}>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref,b:f32}>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref
+  // CHECK fir.store %[[LD]] to %[[ARG0]] : !fir.ref,b:f32}>
+  return
+}
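For context on what the FileCheck patterns above encode: BIND(C) results must follow the same AAPCS64 return classification as the corresponding C structs. The hypothetical C/C++ structs below (names invented here) summarize the three cases the test exercises.

// Hypothetical analogues of the three return-value cases, annotated with
// how AAPCS64 returns each one and how the target-rewrite pass models it.

// HFA (1-4 floating-point members of one type): returned in v0..v3;
// the FIR function signature keeps the record type unchanged.
struct hfa_like { float a, b, c; };

// Small non-HFA composite (<= 16 bytes): returned in x0/x1, which the
// rewrite models as !fir.array<1xi64> or !fir.array<2xi64>.
struct small_like { float a; int b; int c; };  // 12 bytes -> 2 dwords

// Composite larger than 16 bytes that is not an HFA: returned indirectly
// through caller-allocated memory whose address is passed in x8; the
// rewrite adds an sret reference argument aligned to at least 8 bytes.
struct big_like { double a, b, c, d, e; };     // 40 bytes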