-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[flang][FIR] handle argument attributes in fir.call #126711
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-codegen Author: None (jeanPerier) ChangesAdd pretty printer/parser for fir.call argument/result attributes and propagate them to llvm.call. This will allow implementing the TODO about ABI relevant argument attribute in indirect calls here. Full diff: https://github.com/llvm/llvm-project/pull/126711.diff 4 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959e..6346ee0d35292c6 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,14 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
- rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
+ auto llvmCall = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
call, resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
+ if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
+ llvmCall.setArgAttrsAttr(argAttrs);
+ if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
+ llvmCall.setArgAttrsAttr(resAttrs);
return mlir::success();
}
};
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index fa83aa380e489c3..7e50622db08c98b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1121,11 +1121,12 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr(),
- getFastmathAttrName(), getProcedureAttrsAttrName()});
- auto resultTypes{getResultTypes()};
- llvm::SmallVector<mlir::Type> argTypes(
- llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
- p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes);
+ getFastmathAttrName(), getProcedureAttrsAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName()});
+ p << " : ";
+ mlir::call_interface_impl::printFunctionSignature(
+ p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(),
+ /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
@@ -1142,7 +1143,6 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
attrs))
return mlir::failure();
- mlir::Type type;
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();
@@ -1163,13 +1163,17 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
fmfAttrName, attrs))
return mlir::failure();
- if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
- parser.parseType(type))
+ if (parser.parseOptionalAttrDict(attrs) || parser.parseColon())
return mlir::failure();
-
- auto funcType = mlir::dyn_cast<mlir::FunctionType>(type);
- if (!funcType)
+ llvm::SmallVector<mlir::Type> argTypes;
+ llvm::SmallVector<mlir::Type> resTypes;
+ llvm::SmallVector<mlir::DictionaryAttr> argAttrs;
+ llvm::SmallVector<mlir::DictionaryAttr> resultAttrs;
+ if (mlir::call_interface_impl::parseFunctionSignature(
+ parser, argTypes, argAttrs, resTypes, resultAttrs))
return parser.emitError(parser.getNameLoc(), "expected function type");
+ mlir::FunctionType funcType =
+ mlir::FunctionType::get(parser.getContext(), argTypes, resTypes);
if (isDirect) {
if (parser.resolveOperands(operands, funcType.getInputs(),
parser.getNameLoc(), result.operands))
@@ -1183,8 +1187,11 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
parser.getNameLoc(), result.operands))
return mlir::failure();
}
- result.addTypes(funcType.getResults());
result.attributes = attrs;
+ mlir::call_interface_impl::addArgAndResultAttrs(
+ parser.getBuilder(), result, argAttrs, resultAttrs,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ result.addTypes(funcType.getResults());
return mlir::success();
}
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 6d7a4a09918e5a2..a4c176a9e2ee86e 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,3 +2853,21 @@ gpu.module @cuda_device_mod {
// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
// CHECK: llvm.call @malloc
// CHECK: lvm.call @free
+
+// -----
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+ // CHECK: llvm.call @somefunc(%{{.*}}, %{{.*}}) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+ fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+ return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+ // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.signext}) -> (i16 {llvm.signext})
+ %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ return %0 : i16
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 5a30858511f0c9a..1bfcb3a9f3dc898 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -913,3 +913,23 @@ func.func @test_is_assumed_size(%arg0: !fir.class<!fir.array<*:none>>, %arg1 : !
// CHECK-SAME: %[[B:.*]]: !fir.box<!fir.array<?xf32>>)
// CHECK: fir.is_assumed_size %[[A]] : (!fir.class<!fir.array<*:none>>) -> i1
// CHECK: fir.is_assumed_size %[[B]] : (!fir.box<!fir.array<?xf32>>) -> i1
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+// CHECK-SAME: %[[VAL_0:.*]]: i32,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<i64>) {
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+ // CHECK: fir.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+ fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+ return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+// CHECK-SAME: %[[VAL_0:.*]]: i16,
+// CHECK-SAME: %[[VAL_1:.*]]: (i16) -> i16) -> i16 {
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+ // CHECK: fir.call %[[VAL_1]](%[[VAL_0]]) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ return %0 : i16
+}
|
vzakhari
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, Jean!
LGTM, just one question.
Add pretty printer/parser for fir.call argument/result attributes and propagate them to llvm.call. This will allow implementing the TODO about ABI relevant argument attribute in indirect calls.
Add pretty printer/parser for fir.call argument/result attributes and propagate them to llvm.call. This will allow implementing the TODO about ABI relevant argument attribute in indirect calls.
Add pretty printer/parser for fir.call argument/result attributes and propagate them to llvm.call. This will allow implementing the TODO about ABI relevant argument attribute in indirect calls.
Add pretty printer/parser for fir.call argument/result attributes and propagate them to llvm.call.
This will allow implementing the TODO about ABI relevant argument attribute in indirect calls here.