diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index cb4eb8303a495..f938d8d377465 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -589,10 +589,14 @@ struct CallOpConversion : public fir::FIROpConversion { // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr. mlir::arith::AttrConvertFastMathToLLVM attrConvert(call); - rewriter.replaceOpWithNewOp( + auto llvmCall = rewriter.replaceOpWithNewOp( call, resultTys, adaptor.getOperands(), addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(), adaptor.getOperands().size())); + if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr()) + llvmCall.setArgAttrsAttr(argAttrs); + if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr()) + llvmCall.setResAttrsAttr(resAttrs); return mlir::success(); } }; diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index fa83aa380e489..7e50622db08c9 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1121,11 +1121,12 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), {fir::CallOp::getCalleeAttrNameStr(), - getFastmathAttrName(), getProcedureAttrsAttrName()}); - auto resultTypes{getResultTypes()}; - llvm::SmallVector argTypes( - llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1)); - p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes); + getFastmathAttrName(), getProcedureAttrsAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()}); + p << " : "; + mlir::call_interface_impl::printFunctionSignature( + p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(), + /*isVariadic=*/false, getResultTypes(), getResAttrsAttr()); } mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser, @@ -1142,7 +1143,6 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser, attrs)) return mlir::failure(); - mlir::Type type; if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren)) return mlir::failure(); @@ -1163,13 +1163,17 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser, fmfAttrName, attrs)) return mlir::failure(); - if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() || - parser.parseType(type)) + if (parser.parseOptionalAttrDict(attrs) || parser.parseColon()) return mlir::failure(); - - auto funcType = mlir::dyn_cast(type); - if (!funcType) + llvm::SmallVector argTypes; + llvm::SmallVector resTypes; + llvm::SmallVector argAttrs; + llvm::SmallVector resultAttrs; + if (mlir::call_interface_impl::parseFunctionSignature( + parser, argTypes, argAttrs, resTypes, resultAttrs)) return parser.emitError(parser.getNameLoc(), "expected function type"); + mlir::FunctionType funcType = + mlir::FunctionType::get(parser.getContext(), argTypes, resTypes); if (isDirect) { if (parser.resolveOperands(operands, funcType.getInputs(), parser.getNameLoc(), result.operands)) @@ -1183,8 +1187,11 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser, parser.getNameLoc(), result.operands)) return mlir::failure(); } - result.addTypes(funcType.getResults()); result.attributes = attrs; + mlir::call_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, argAttrs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); + result.addTypes(funcType.getResults()); return mlir::success(); } diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index 6d7a4a09918e5..c11cfd5d5faa1 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -2853,3 +2853,21 @@ gpu.module @cuda_device_mod { // CHECK: llvm.func @malloc(i64) -> !llvm.ptr // CHECK: llvm.call @malloc // CHECK: lvm.call @free + +// ----- + +func.func private @somefunc(i32, !fir.ref) + +// CHECK-LABEL: @test_call_arg_attrs_direct +func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref) { + // CHECK: llvm.call @somefunc(%{{.*}}, %{{.*}}) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref {llvm.byval = i64}) -> () + return +} + +// CHECK-LABEL: @test_call_arg_attrs_indirect +func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 { + // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + return %0 : i16 +} diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir index 5a30858511f0c..1bfcb3a9f3dc8 100644 --- a/flang/test/Fir/fir-ops.fir +++ b/flang/test/Fir/fir-ops.fir @@ -913,3 +913,23 @@ func.func @test_is_assumed_size(%arg0: !fir.class>, %arg1 : ! // CHECK-SAME: %[[B:.*]]: !fir.box>) // CHECK: fir.is_assumed_size %[[A]] : (!fir.class>) -> i1 // CHECK: fir.is_assumed_size %[[B]] : (!fir.box>) -> i1 + +func.func private @somefunc(i32, !fir.ref) + +// CHECK-LABEL: @test_call_arg_attrs_direct +// CHECK-SAME: %[[VAL_0:.*]]: i32, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref) { +func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref) { + // CHECK: fir.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !fir.ref {llvm.byval = i64}) -> () + fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref {llvm.byval = i64}) -> () + return +} + +// CHECK-LABEL: @test_call_arg_attrs_indirect +// CHECK-SAME: %[[VAL_0:.*]]: i16, +// CHECK-SAME: %[[VAL_1:.*]]: (i16) -> i16) -> i16 { +func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 { + // CHECK: fir.call %[[VAL_1]](%[[VAL_0]]) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + return %0 : i16 +}