Skip to content

Commit 65075a8

Browse files
authored
[flang][FIR] handle argument attributes in fir.call (#126711)
Add pretty printer/parser for fir.call argument/result attributes and propagate them to llvm.call. This will allow implementing the TODO about ABI relevant argument attribute in indirect calls.
1 parent 39f0f0a commit 65075a8

File tree

4 files changed

+62
-13
lines changed

4 files changed

+62
-13
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,14 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
589589
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
590590
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
591591
attrConvert(call);
592-
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
592+
auto llvmCall = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
593593
call, resultTys, adaptor.getOperands(),
594594
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
595595
adaptor.getOperands().size()));
596+
if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
597+
llvmCall.setArgAttrsAttr(argAttrs);
598+
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
599+
llvmCall.setResAttrsAttr(resAttrs);
596600
return mlir::success();
597601
}
598602
};

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,11 +1121,12 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
11211121

11221122
p.printOptionalAttrDict((*this)->getAttrs(),
11231123
{fir::CallOp::getCalleeAttrNameStr(),
1124-
getFastmathAttrName(), getProcedureAttrsAttrName()});
1125-
auto resultTypes{getResultTypes()};
1126-
llvm::SmallVector<mlir::Type> argTypes(
1127-
llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
1128-
p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes);
1124+
getFastmathAttrName(), getProcedureAttrsAttrName(),
1125+
getArgAttrsAttrName(), getResAttrsAttrName()});
1126+
p << " : ";
1127+
mlir::call_interface_impl::printFunctionSignature(
1128+
p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(),
1129+
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
11291130
}
11301131

11311132
mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
@@ -1142,7 +1143,6 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
11421143
attrs))
11431144
return mlir::failure();
11441145

1145-
mlir::Type type;
11461146
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
11471147
return mlir::failure();
11481148

@@ -1163,13 +1163,17 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
11631163
fmfAttrName, attrs))
11641164
return mlir::failure();
11651165

1166-
if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
1167-
parser.parseType(type))
1166+
if (parser.parseOptionalAttrDict(attrs) || parser.parseColon())
11681167
return mlir::failure();
1169-
1170-
auto funcType = mlir::dyn_cast<mlir::FunctionType>(type);
1171-
if (!funcType)
1168+
llvm::SmallVector<mlir::Type> argTypes;
1169+
llvm::SmallVector<mlir::Type> resTypes;
1170+
llvm::SmallVector<mlir::DictionaryAttr> argAttrs;
1171+
llvm::SmallVector<mlir::DictionaryAttr> resultAttrs;
1172+
if (mlir::call_interface_impl::parseFunctionSignature(
1173+
parser, argTypes, argAttrs, resTypes, resultAttrs))
11721174
return parser.emitError(parser.getNameLoc(), "expected function type");
1175+
mlir::FunctionType funcType =
1176+
mlir::FunctionType::get(parser.getContext(), argTypes, resTypes);
11731177
if (isDirect) {
11741178
if (parser.resolveOperands(operands, funcType.getInputs(),
11751179
parser.getNameLoc(), result.operands))
@@ -1183,8 +1187,11 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
11831187
parser.getNameLoc(), result.operands))
11841188
return mlir::failure();
11851189
}
1186-
result.addTypes(funcType.getResults());
11871190
result.attributes = attrs;
1191+
mlir::call_interface_impl::addArgAndResultAttrs(
1192+
parser.getBuilder(), result, argAttrs, resultAttrs,
1193+
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1194+
result.addTypes(funcType.getResults());
11881195
return mlir::success();
11891196
}
11901197

flang/test/Fir/convert-to-llvm.fir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2853,3 +2853,21 @@ gpu.module @cuda_device_mod {
28532853
// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
28542854
// CHECK: llvm.call @malloc
28552855
// CHECK: lvm.call @free
2856+
2857+
// -----
2858+
2859+
func.func private @somefunc(i32, !fir.ref<i64>)
2860+
2861+
// CHECK-LABEL: @test_call_arg_attrs_direct
2862+
func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
2863+
// CHECK: llvm.call @somefunc(%{{.*}}, %{{.*}}) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
2864+
fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
2865+
return
2866+
}
2867+
2868+
// CHECK-LABEL: @test_call_arg_attrs_indirect
2869+
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
2870+
// CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
2871+
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
2872+
return %0 : i16
2873+
}

flang/test/Fir/fir-ops.fir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,23 @@ func.func @test_is_assumed_size(%arg0: !fir.class<!fir.array<*:none>>, %arg1 : !
913913
// CHECK-SAME: %[[B:.*]]: !fir.box<!fir.array<?xf32>>)
914914
// CHECK: fir.is_assumed_size %[[A]] : (!fir.class<!fir.array<*:none>>) -> i1
915915
// CHECK: fir.is_assumed_size %[[B]] : (!fir.box<!fir.array<?xf32>>) -> i1
916+
917+
func.func private @somefunc(i32, !fir.ref<i64>)
918+
919+
// CHECK-LABEL: @test_call_arg_attrs_direct
920+
// CHECK-SAME: %[[VAL_0:.*]]: i32,
921+
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<i64>) {
922+
func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
923+
// CHECK: fir.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
924+
fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
925+
return
926+
}
927+
928+
// CHECK-LABEL: @test_call_arg_attrs_indirect
929+
// CHECK-SAME: %[[VAL_0:.*]]: i16,
930+
// CHECK-SAME: %[[VAL_1:.*]]: (i16) -> i16) -> i16 {
931+
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
932+
// CHECK: fir.call %[[VAL_1]](%[[VAL_0]]) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
933+
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
934+
return %0 : i16
935+
}

0 commit comments

Comments
 (0)