Skip to content

Commit 2a6e834

Browse files
committed
[flang] add ABI argument attributes in indirect calls
1 parent 25c0554 commit 2a6e834

File tree

4 files changed

+100
-11
lines changed

4 files changed

+100
-11
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,36 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
593593
call, resultTys, adaptor.getOperands(),
594594
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
595595
adaptor.getOperands().size()));
596-
if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
597-
llvmCall.setArgAttrsAttr(argAttrs);
596+
if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
597+
// sret and byval type needs to be converted.
598+
auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
599+
return mlir::TypeAttr::get(convertType(
600+
llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
601+
};
602+
llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
603+
for (auto argAttrs : argAttrsArray) {
604+
llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
605+
for (const mlir::NamedAttribute &attr :
606+
llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
607+
if (attr.getName().getValue() ==
608+
mlir::LLVM::LLVMDialect::getByValAttrName()) {
609+
convertedAttrs.push_back(rewriter.getNamedAttr(
610+
mlir::LLVM::LLVMDialect::getByValAttrName(),
611+
convertTypeAttr(attr)));
612+
} else if (attr.getName().getValue() ==
613+
mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
614+
convertedAttrs.push_back(rewriter.getNamedAttr(
615+
mlir::LLVM::LLVMDialect::getStructRetAttrName(),
616+
convertTypeAttr(attr)));
617+
} else {
618+
convertedAttrs.push_back(attr);
619+
}
620+
}
621+
newArgAttrsArray.emplace_back(
622+
mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
623+
}
624+
llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
625+
}
598626
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
599627
llvmCall.setResAttrsAttr(resAttrs);
600628
return mlir::success();

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -534,19 +534,44 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
534534
} else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
535535
fir::CallOp newCall;
536536
if (callOp.getCallee()) {
537-
newCall =
538-
rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
537+
newCall = rewriter->create<fir::CallOp>(loc, *callOp.getCallee(),
538+
newResTys, newOpers);
539539
} else {
540-
// TODO: llvm dialect must be updated to propagate argument on
541-
// attributes for indirect calls. See:
542-
// https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
543-
if (hasByValOrSRetArgs(newInTyAndAttrs))
544-
TODO(loc,
545-
"passing argument or result on the stack in indirect calls");
546540
newOpers[0].setType(mlir::FunctionType::get(
547541
callOp.getContext(),
548542
mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
549-
newCall = rewriter->create<A>(loc, newResTys, newOpers);
543+
newCall = rewriter->create<fir::CallOp>(loc, newResTys, newOpers);
544+
// Set ABI argument attributes on call operation since they are not
545+
// accessible via a FuncOp in indirect calls.
546+
if (hasByValOrSRetArgs(newInTyAndAttrs)) {
547+
llvm::SmallVector<mlir::Attribute> argAttrsArray;
548+
for (const auto &arg :
549+
llvm::ArrayRef<fir::CodeGenSpecifics::TypeAndAttr>(
550+
newInTyAndAttrs)
551+
.drop_front(dropFront)) {
552+
mlir::NamedAttrList argAttrs;
553+
const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
554+
if (attr.isByVal()) {
555+
mlir::Type elemType =
556+
fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
557+
argAttrs.set(mlir::LLVM::LLVMDialect::getByValAttrName(),
558+
mlir::TypeAttr::get(elemType));
559+
} else if (attr.isSRet()) {
560+
mlir::Type elemType =
561+
fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
562+
argAttrs.set(mlir::LLVM::LLVMDialect::getStructRetAttrName(),
563+
mlir::TypeAttr::get(elemType));
564+
if (auto align = attr.getAlignment()) {
565+
argAttrs.set(mlir::LLVM::LLVMDialect::getAlignAttrName(),
566+
rewriter->getIntegerAttr(
567+
rewriter->getIntegerType(32), align));
568+
}
569+
}
570+
argAttrsArray.emplace_back(
571+
argAttrs.getDictionary(rewriter->getContext()));
572+
}
573+
newCall.setArgAttrsAttr(rewriter->getArrayAttr(argAttrsArray));
574+
}
550575
}
551576
LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
552577
if (wrap)

flang/test/Fir/convert-to-llvm.fir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,3 +2871,17 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
28712871
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
28722872
return %0 : i16
28732873
}
2874+
2875+
// CHECK-LABEL: @test_byval
2876+
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
2877+
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
2878+
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
2879+
return
2880+
}
2881+
2882+
// CHECK-LABEL: @test_sret
2883+
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
2884+
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
2885+
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
2886+
return
2887+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Test that ABI attributes are set in indirect calls to BIND(C) functions.
2+
// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
3+
4+
func.func @test(%arg0: () -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
5+
%0 = fir.load %arg1 : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
6+
%1 = fir.convert %arg0 : (() -> ()) -> ((!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ())
7+
fir.call %1(%0, %arg2) proc_attrs<bind_c> : (!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ()
8+
return
9+
}
10+
// CHECK-LABEL: func.func @test(
11+
// CHECK-SAME: %[[VAL_0:.*]]: () -> (),
12+
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>,
13+
// CHECK-SAME: %[[VAL_2:.*]]: f64) {
14+
// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
15+
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (() -> ()) -> ((!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> ())
16+
// CHECK: %[[VAL_5:.*]] = llvm.intr.stacksave : !llvm.ptr
17+
// CHECK: %[[VAL_6:.*]] = fir.alloca !fir.type<t{a:!fir.array<5xf64>}>
18+
// CHECK: fir.store %[[VAL_3]] to %[[VAL_6]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
19+
// CHECK: fir.call %[[VAL_4]](%[[VAL_6]], %[[VAL_2]]) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
20+
// CHECK: llvm.intr.stackrestore %[[VAL_5]] : !llvm.ptr
21+
// CHECK: return
22+
// CHECK: }

0 commit comments

Comments
 (0)