From f062d7811e81a4be97811d6addd53187cf11e96d Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Mon, 3 Feb 2025 08:31:01 -0800 Subject: [PATCH 1/5] [mlir][LLVM] add argument and result attributes to llvm.call --- llvm/include/llvm/IR/InstrTypes.h | 11 ++++ .../include/mlir/Target/LLVMIR/ModuleImport.h | 8 ++- .../mlir/Target/LLVMIR/ModuleTranslation.h | 9 ++- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 64 ++++++++++++------- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 21 ++++++ mlir/lib/Target/LLVMIR/ModuleImport.cpp | 34 ++++++++++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 52 +++++++++++---- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 + mlir/test/Dialect/LLVMIR/roundtrip.mlir | 20 ++++++ .../LLVMIR/Import/call-argument-attributes.ll | 22 +++++++ .../LLVMIR/call-argument-attributes.mlir | 17 +++++ 11 files changed, 220 insertions(+), 40 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll create mode 100644 mlir/test/Target/LLVMIR/call-argument-attributes.mlir diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index 26be02d4b193d..90fe864d4ae71 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1490,6 +1490,11 @@ class CallBase : public Instruction { Attrs = Attrs.addRetAttribute(getContext(), Attr); } + /// Adds attributes to the return value. + void addRetAttrs(const AttrBuilder &B) { + Attrs = Attrs.addRetAttributes(getContext(), B); + } + /// Adds the attribute to the indicated argument void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) { assert(ArgNo < arg_size() && "Out of bounds"); @@ -1502,6 +1507,12 @@ class CallBase : public Instruction { Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr); } + /// Adds attributes to the indicated argument + void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) { + assert(ArgNo < arg_size() && "Out of bounds"); + Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B); + } + /// removes the attribute from the list of attributes. void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) { Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 80ae4d679624c..d09c73c2f467d 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -335,14 +335,18 @@ class ModuleImport { FailureOr convertFunctionType(llvm::CallBase *callInst); /// Returns the callee name, or an empty symbol if the call is not direct. FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst); - /// Converts the parameter attributes attached to `func` and adds them to - /// the `funcOp`. + /// Converts the parameter and result attributes attached to `func` and adds + /// them to the `funcOp`. void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, OpBuilder &builder); /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding /// DictionaryAttr for the LLVM dialect. DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, OpBuilder &builder); + /// Converts the parameter and result attributes attached to `call` and adds + /// them to the `callOp`. + void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, + OpBuilder &builder); /// Converts the attributes attached to `inst` and adds them to the `op`. LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op); /// Converts the attributes attached to `inst` and adds them to the `op`. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 1b62437761ed9..88fc17ca4fda2 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -228,6 +228,11 @@ class ModuleTranslation { /*recordInsertions=*/false); } + /// Translates parameter attributes of a call and adds them to the returned + /// AttrBuilder. Returns failure if any of the translations failed. + FailureOr convertParameterAttrs(CallOp callOp, int argIdx, + DictionaryAttr paramAttrs); + /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name); @@ -346,8 +351,8 @@ class ModuleTranslation { convertDialectAttributes(Operation *op, ArrayRef instructions); - /// Translates parameter attributes and adds them to the returned AttrBuilder. - /// Returns failure if any of the translations failed. + /// Translates parameter attributes of a function and adds them to the + /// returned AttrBuilder. Returns failure if any of the translations failed. FailureOr convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index a6e996f3fb810..25d45f70b09ac 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1335,42 +1335,53 @@ void CallOp::print(OpAsmPrinter &p) { getVarCalleeTypeAttrName(), getCConvAttrName(), getOperandSegmentSizesAttrName(), getOpBundleSizesAttrName(), - getOpBundleTagsAttrName()}); + getOpBundleTagsAttrName(), getArgAttrsAttrName(), + getResAttrsAttrName()}); p << " : "; if (!isDirect) p << getOperand(0).getType() << ", "; // Reconstruct the function MLIR function type from operand and result types. - p.printFunctionalType(args.getTypes(), getResultTypes()); + call_interface_impl::printFunctionSignature( + p, args.getTypes(), getArgAttrsAttr(), + /*isVariadic=*/false, getResultTypes(), getResAttrsAttr()); } /// Parses the type of a call operation and resolves the operands if the parsing /// succeeds. Returns failure otherwise. static ParseResult parseCallTypeAndResolveOperands( OpAsmParser &parser, OperationState &result, bool isDirect, - ArrayRef operands) { + ArrayRef operands, + SmallVectorImpl &argAttrs, + SmallVectorImpl &resultAttrs) { SMLoc trailingTypesLoc = parser.getCurrentLocation(); SmallVector types; - if (parser.parseColonTypeList(types)) + if (parser.parseColon()) return failure(); - - if (isDirect && types.size() != 1) - return parser.emitError(trailingTypesLoc, - "expected direct call to have 1 trailing type"); - if (!isDirect && types.size() != 2) - return parser.emitError(trailingTypesLoc, - "expected indirect call to have 2 trailing types"); - - auto funcType = llvm::dyn_cast(types.pop_back_val()); - if (!funcType) + if (!isDirect) { + types.emplace_back(); + if (parser.parseType(types.back())) + return failure(); + if (parser.parseOptionalComma()) + return parser.emitError( + trailingTypesLoc, "expected indirect call to have 2 trailing types"); + } + SmallVector argTypes; + SmallVector resTypes; + if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs, + resTypes, resultAttrs)) { + if (isDirect) + return parser.emitError(trailingTypesLoc, + "expected direct call to have 1 trailing types"); return parser.emitError(trailingTypesLoc, "expected trailing function type"); - if (funcType.getNumResults() > 1) + } + + if (resTypes.size() > 1) return parser.emitError(trailingTypesLoc, "expected function with 0 or 1 result"); - if (funcType.getNumResults() == 1 && - llvm::isa(funcType.getResult(0))) + if (resTypes.size() == 1 && llvm::isa(resTypes[0])) return parser.emitError(trailingTypesLoc, "expected a non-void result type"); @@ -1378,12 +1389,12 @@ static ParseResult parseCallTypeAndResolveOperands( // indirect calls, while the types list is emtpy for direct calls. // Append the function input types to resolve the call operation // operands. - llvm::append_range(types, funcType.getInputs()); + llvm::append_range(types, argTypes); if (parser.resolveOperands(operands, types, parser.getNameLoc(), result.operands)) return failure(); - if (funcType.getNumResults() != 0) - result.addTypes(funcType.getResults()); + if (resTypes.size() != 0) + result.addTypes(resTypes); return success(); } @@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // Parse the trailing type list and resolve the operands. - if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) + SmallVector argAttrs; + SmallVector resultAttrs; + if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, + argAttrs, resultAttrs)) return failure(); + call_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, argAttrs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, opBundleOperandTypes, getOpBundleSizesAttrName(result.name))) @@ -1721,7 +1738,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // Parse the trailing type list and resolve the function operands. - if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) + SmallVector argAttrs; + SmallVector resultAttrs; + if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, + argAttrs, resultAttrs)) return failure(); if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, opBundleOperandTypes, diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 2084e527773ca..52f42df60f001 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -265,6 +265,27 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getWillReturnAttr()) call->addFnAttr(llvm::Attribute::WillReturn); + if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) + for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { + if (auto argAttrs = llvm::cast(argAttrsAttr)) { + FailureOr attrBuilder = + moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + + ArrayAttr resAttrsArray = callOp.getResAttrsAttr(); + if (resAttrsArray && resAttrsArray.size() == 1) + if (auto resAttrs = llvm::cast(resAttrsArray[0])) { + FailureOr attrBuilder = + moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { llvm::MemoryEffects memEffects = llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 5ebde22cccbdf..8d779c5083eb6 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1706,6 +1706,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { auto callOp = builder.create(loc, *funcTy, callee, *operands); if (failed(convertCallAttributes(callInst, callOp))) return failure(); + // Handle parameter and result attributes. + convertParameterAttributes(callInst, callOp, builder); return callOp.getOperation(); }(); @@ -2149,6 +2151,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); } +void ModuleImport::convertParameterAttributes(llvm::CallBase *call, + CallOpInterface callOp, + OpBuilder &builder) { + auto llvmAttrs = call->getAttributes(); + SmallVector llvmArgAttrsSet; + bool anyArgAttrs = false; + for (size_t i = 0, e = call->arg_size(); i < e; ++i) { + llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i)); + if (llvmArgAttrsSet.back().hasAttributes()) + anyArgAttrs = true; + } + auto getArrayAttr = [&](ArrayRef dictAttrs) { + SmallVector attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; + if (anyArgAttrs) { + SmallVector argAttrs; + for (auto &llvmArgAttrs : llvmArgAttrsSet) + argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); + callOp.setArgAttrsAttr(getArrayAttr(argAttrs)); + } + + llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); + if (!llvmResAttr.hasAttributes()) + return; + SmallVector resAttrs; + resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder)); + callOp.setResAttrsAttr(getArrayAttr(resAttrs)); +} + template static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) { op.setCConv(convertCConvFromLLVM(inst->getCallingConv())); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 4367100e3aca6..b2d2c1cddca31 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1563,6 +1563,26 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func, } } +static void convertParameterAttr(llvm::AttrBuilder &attrBuilder, + llvm::Attribute::AttrKind llvmKind, + NamedAttribute namedAttr, + ModuleTranslation &moduleTranslation) { + llvm::TypeSwitch(namedAttr.getValue()) + .Case([&](auto typeAttr) { + attrBuilder.addTypeAttr( + llvmKind, moduleTranslation.convertType(typeAttr.getValue())); + }) + .Case([&](auto intAttr) { + attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); + }) + .Case([&](auto) { attrBuilder.addAttribute(llvmKind); }) + .Case([&](auto rangeAttr) { + attrBuilder.addConstantRangeAttr( + llvmKind, + llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper())); + }); +} + FailureOr ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs) { @@ -1573,20 +1593,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, auto it = attrNameToKindMapping.find(namedAttr.getName()); if (it != attrNameToKindMapping.end()) { llvm::Attribute::AttrKind llvmKind = it->second; - - llvm::TypeSwitch(namedAttr.getValue()) - .Case([&](auto typeAttr) { - attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue())); - }) - .Case([&](auto intAttr) { - attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); - }) - .Case([&](auto) { attrBuilder.addAttribute(llvmKind); }) - .Case([&](auto rangeAttr) { - attrBuilder.addConstantRangeAttr( - llvmKind, llvm::ConstantRange(rangeAttr.getLower(), - rangeAttr.getUpper())); - }); + convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); } else if (namedAttr.getNameDialect()) { if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this))) return failure(); @@ -1596,6 +1603,23 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, return attrBuilder; } +FailureOr +ModuleTranslation::convertParameterAttrs(CallOp, int argIdx, + DictionaryAttr paramAttrs) { + llvm::AttrBuilder attrBuilder(llvmModule->getContext()); + auto attrNameToKindMapping = getAttrNameToKindMapping(); + + for (auto namedAttr : paramAttrs) { + auto it = attrNameToKindMapping.find(namedAttr.getName()); + if (it != attrNameToKindMapping.end()) { + llvm::Attribute::AttrKind llvmKind = it->second; + convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); + } + } + + return attrBuilder; +} + LogicalResult ModuleTranslation::convertFunctionSignatures() { // Declare all functions first because there may be function calls that form a // call graph with cycles, or global initializers that reference functions. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 5c939318fe3ed..76c57e76f8493 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func, %arg : i8) { func.func private @standard_func_callee() func.func @call_missing_ptr_type(%arg : i8) { + // expected-error@+2 {{expected '('}} // expected-error@+1 {{expected direct call to have 1 trailing type}} llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8) llvm.return @@ -251,6 +252,7 @@ func.func @call_non_pointer_type(%callee : !llvm.func, %arg : i8) { // ----- func.func @call_non_function_type(%callee : !llvm.ptr, %arg : i8) { + // expected-error@+2 {{expected '('}} // expected-error@+1 {{expected trailing function type}} llvm.call %callee(%arg) : !llvm.ptr, !llvm.func llvm.return diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 88660ce598f3c..e565772f06b03 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -941,3 +941,23 @@ llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) { llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1 llvm.return } + +llvm.func @somefunc(i32, !llvm.ptr) + +// CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( +// CHECK-SAME: %[[VAL_0:.*]]: i32, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) +llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { + // CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.return +} + +// CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect( +// CHECK-SAME: %[[VAL_0:.*]]: i16, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr +llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { + // CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + llvm.return %0 : i16 +} diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll new file mode 100644 index 0000000000000..2c86ca6b03125 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll @@ -0,0 +1,22 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s + +; CHECK-LABEL: llvm.func @somefunc(i32, !llvm.ptr) +declare void @somefunc(i32, ptr) + +; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( +; CHECK-SAME: %[[VAL_0:.*]]: i32, +; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) +define void @test_call_arg_attrs_direct(i32 %0, ptr %1) { + ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + call void @somefunc(i32 %0, ptr byval(i64) %1) + ret void +} + +; CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect( +; CHECK-SAME: %[[VAL_0:.*]]: i16, +; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr +define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) { +; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %3 = tail call signext i16 %1(i16 noundef signext %0) + ret i16 %3 +} diff --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir new file mode 100644 index 0000000000000..89b1f29a68623 --- /dev/null +++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @somefunc(i32, !llvm.ptr) + +// CHECK-LABEL: define void @test_call_arg_attrs_direct +llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { + // CHECK: call void @somefunc(i32 %{{.*}}, ptr byval(i64) %{{.*}}) + llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.return +} + +// CHECK-LABEL: define i16 @test_call_arg_attrs_indirec +llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { + // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}}) + %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + llvm.return %0 : i16 +} From b9a67a93fc4e7a292e4be66d3eb3a5393e9da412 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Mon, 3 Feb 2025 07:20:43 -0800 Subject: [PATCH 2/5] remove argIdx arg + style nits --- mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h | 2 +- .../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 10 ++++++---- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 3 +-- mlir/test/Target/LLVMIR/call-argument-attributes.mlir | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 88fc17ca4fda2..25f17ba4f6a35 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -230,7 +230,7 @@ class ModuleTranslation { /// Translates parameter attributes of a call and adds them to the returned /// AttrBuilder. Returns failure if any of the translations failed. - FailureOr convertParameterAttrs(CallOp callOp, int argIdx, + FailureOr convertParameterAttrs(CallOp callOp, DictionaryAttr paramAttrs); /// Gets the named metadata in the LLVM IR module being constructed, creating diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 52f42df60f001..822d6d7bb467b 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -265,26 +265,28 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getWillReturnAttr()) call->addFnAttr(llvm::Attribute::WillReturn); - if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) + if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) { for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { if (auto argAttrs = llvm::cast(argAttrsAttr)) { FailureOr attrBuilder = - moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs); + moduleTranslation.convertParameterAttrs(callOp, argAttrs); if (failed(attrBuilder)) return failure(); call->addParamAttrs(argIdx, *attrBuilder); } } + } ArrayAttr resAttrsArray = callOp.getResAttrsAttr(); - if (resAttrsArray && resAttrsArray.size() == 1) + if (resAttrsArray && resAttrsArray.size() == 1) { if (auto resAttrs = llvm::cast(resAttrsArray[0])) { FailureOr attrBuilder = - moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs); + moduleTranslation.convertParameterAttrs(callOp, resAttrs); if (failed(attrBuilder)) return failure(); call->addRetAttrs(*attrBuilder); } + } if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { llvm::MemoryEffects memEffects = diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index b2d2c1cddca31..5cee1fc5c1cf4 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1604,8 +1604,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, } FailureOr -ModuleTranslation::convertParameterAttrs(CallOp, int argIdx, - DictionaryAttr paramAttrs) { +ModuleTranslation::convertParameterAttrs(CallOp, DictionaryAttr paramAttrs) { llvm::AttrBuilder attrBuilder(llvmModule->getContext()); auto attrNameToKindMapping = getAttrNameToKindMapping(); diff --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir index 89b1f29a68623..b3d286dcda504 100644 --- a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir +++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir @@ -9,7 +9,7 @@ llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { llvm.return } -// CHECK-LABEL: define i16 @test_call_arg_attrs_indirec +// CHECK-LABEL: define i16 @test_call_arg_attrs_indirect llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}}) %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) From 25a74a909bce87e19a0a842572e3865894f96f43 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Thu, 6 Feb 2025 01:17:06 -0800 Subject: [PATCH 3/5] handle invoke parameter attributes in import, export, pretty printing --- .../mlir/Target/LLVMIR/ModuleTranslation.h | 3 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 17 ++++-- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 55 +++++++++++-------- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 3 + mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 +- mlir/test/Dialect/LLVMIR/roundtrip.mlir | 25 +++++++++ .../Import/invoke-argument-attributes.ll | 26 +++++++++ .../LLVMIR/invoke-argument-attributes.mlir | 25 +++++++++ 8 files changed, 126 insertions(+), 30 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll create mode 100644 mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 25f17ba4f6a35..3ad7f1e33f0a3 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -230,8 +230,7 @@ class ModuleTranslation { /// Translates parameter attributes of a call and adds them to the returned /// AttrBuilder. Returns failure if any of the translations failed. - FailureOr convertParameterAttrs(CallOp callOp, - DictionaryAttr paramAttrs); + FailureOr convertParameterAttrs(DictionaryAttr paramAttrs); /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 25d45f70b09ac..bea90acc9364f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1660,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) { {getCalleeAttrName(), getOperandSegmentSizeAttr(), getCConvAttrName(), getVarCalleeTypeAttrName(), getOpBundleSizesAttrName(), - getOpBundleTagsAttrName()}); + getOpBundleTagsAttrName(), getArgAttrsAttrName(), + getResAttrsAttrName()}); p << " : "; if (!isDirect) p << getOperand(0).getType() << ", "; - p.printFunctionalType( - llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1), - getResultTypes()); + call_interface_impl::printFunctionSignature( + p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(), + getArgAttrsAttr(), + /*isVariadic=*/false, getResultTypes(), getResAttrsAttr()); } // ::= `llvm.invoke` (cconv)? (function-id | ssa-use) @@ -1676,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) { // `unwind` bb-id (`[` ssa-use-and-type-list `]`)? // ( `vararg(` var-callee-type `)` )? // ( `[` op-bundles-list `]` )? -// attribute-dict? `:` (type `,`)? function-type +// attribute-dict? `:` (type `,`)? +// function-type-with-argument-attributes ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operands; SymbolRefAttr funcAttr; @@ -1743,6 +1746,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, argAttrs, resultAttrs)) return failure(); + call_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, argAttrs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); + if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, opBundleOperandTypes, getOpBundleSizesAttrName(result.name))) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 822d6d7bb467b..fc295ab7e8e1d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -224,6 +224,34 @@ static void convertLinkerOptionsOp(ArrayAttr options, linkerMDNode->addOperand(listMDNode); } +static LogicalResult +convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, + LLVM::ModuleTranslation &moduleTranslation) { + if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) { + for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { + if (auto argAttrs = llvm::cast(argAttrsAttr)) { + FailureOr attrBuilder = + moduleTranslation.convertParameterAttrs(argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + } + + ArrayAttr resAttrsArray = callOp.getResAttrsAttr(); + if (resAttrsArray && resAttrsArray.size() == 1) { + if (auto resAttrs = llvm::cast(resAttrsArray[0])) { + FailureOr attrBuilder = + moduleTranslation.convertParameterAttrs(resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + } + return success(); +} + static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -265,28 +293,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getWillReturnAttr()) call->addFnAttr(llvm::Attribute::WillReturn); - if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) { - for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { - if (auto argAttrs = llvm::cast(argAttrsAttr)) { - FailureOr attrBuilder = - moduleTranslation.convertParameterAttrs(callOp, argAttrs); - if (failed(attrBuilder)) - return failure(); - call->addParamAttrs(argIdx, *attrBuilder); - } - } - } - - ArrayAttr resAttrsArray = callOp.getResAttrsAttr(); - if (resAttrsArray && resAttrsArray.size() == 1) { - if (auto resAttrs = llvm::cast(resAttrsArray[0])) { - FailureOr attrBuilder = - moduleTranslation.convertParameterAttrs(callOp, resAttrs); - if (failed(attrBuilder)) - return failure(); - call->addRetAttrs(*attrBuilder); - } - } + if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation))) + return failure(); if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { llvm::MemoryEffects memEffects = @@ -395,6 +403,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, operandsRef.drop_front(), opBundles); } result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); + if (failed( + convertParameterAndResultAttrs(invOp, result, moduleTranslation))) + return failure(); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result if (invOp->getNumResults() != 0) { diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 8d779c5083eb6..a55a65f9067ad 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1788,6 +1788,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (failed(convertInvokeAttributes(invokeInst, invokeOp))) return failure(); + // Handle parameter and result attributes. + convertParameterAttributes(invokeInst, invokeOp, builder); + if (!invokeInst->getType()->isVoidTy()) mapValue(inst, invokeOp.getResults().front()); else diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 5cee1fc5c1cf4..21e15b5dbf96d 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1604,7 +1604,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, } FailureOr -ModuleTranslation::convertParameterAttrs(CallOp, DictionaryAttr paramAttrs) { +ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) { llvm::AttrBuilder attrBuilder(llvmModule->getContext()); auto attrNameToKindMapping = getAttrNameToKindMapping(); diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index e565772f06b03..09a0cd57e2675 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -961,3 +961,28 @@ llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) llvm.return %0 : i16 } + +// CHECK-LABEL: llvm.func @test_invoke_arg_attrs( +// CHECK-SAME: %[[VAL_0:.*]]: i16) attributes {personality = @__gxx_personality_v0} { +llvm.func @test_invoke_arg_attrs(%arg0: i16) attributes { personality = @__gxx_personality_v0 } { + // CHECK: llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> () + llvm.invoke @somefunc(%arg0) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> () +^bb1: + %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.return +^bb2: + llvm.return +} + +// CHECK-LABEL: llvm.func @test_invoke_arg_attrs_indirect( +// CHECK-SAME: %[[VAL_0:.*]]: i16, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) -> i16 attributes {personality = @__gxx_personality_v0} { +llvm.func @test_invoke_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 attributes { personality = @__gxx_personality_v0 } { + // CHECK: llvm.invoke %[[VAL_1]](%[[VAL_0]]) to ^bb2 unwind ^bb1 : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %0 = llvm.invoke %arg1(%arg0) to ^bb2 unwind ^bb1 : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) +^bb1: + %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.return %0 : i16 +^bb2: + llvm.return %0 : i16 +} diff --git a/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll new file mode 100644 index 0000000000000..e606961e1a252 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll @@ -0,0 +1,26 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s + +; CHECK-LABEL: llvm.func @test( +; CHECK-SAME: %[[VAL_0:.*]]: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} { +define signext i16 @test(i16 noundef signext %0) personality ptr @__gxx_personality_v0 { +; CHECK: %[[VAL_3:.*]] = llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %2 = invoke signext i16 @somefunc(i16 noundef signext %0) + to label %7 unwind label %3 + +3: ; preds = %1 + %4 = landingpad { ptr, i32 } + catch ptr null + %5 = extractvalue { ptr, i32 } %4, 0 + %6 = tail call ptr @__cxa_begin_catch(ptr %5) #2 + tail call void @__cxa_end_catch() + br label %7 + +7: ; preds = %1, %3 + %8 = phi i16 [ 0, %3 ], [ %2, %1 ] + ret i16 %8 +} + +declare noundef signext i16 @somefunc(i16 noundef signext) +declare i32 @__gxx_personality_v0(...) +declare ptr @__cxa_begin_catch(ptr) +declare void @__cxa_end_catch() diff --git a/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir new file mode 100644 index 0000000000000..5d6e49bfe09e8 --- /dev/null +++ b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @test +llvm.func @test(%arg0: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} { + %0 = llvm.mlir.zero : !llvm.ptr + %1 = llvm.mlir.constant(0 : i16) : i16 +// CHECK: invoke signext i16 @somefunc(i16 noundef signext %{{.*}}) +// CHECK-NEXT: to label %{{.*}} unwind label %{{.*}} + %2 = llvm.invoke @somefunc(%arg0) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) +^bb1: // pred: ^bb0 + %3 = llvm.landingpad (catch %0 : !llvm.ptr) : !llvm.struct<(ptr, i32)> + %4 = llvm.extractvalue %3[0] : !llvm.struct<(ptr, i32)> + %5 = llvm.call tail @__cxa_begin_catch(%4) : (!llvm.ptr) -> !llvm.ptr + llvm.call tail @__cxa_end_catch() : () -> () + llvm.br ^bb3(%1 : i16) +^bb2: // pred: ^bb0 + llvm.br ^bb3(%2 : i16) +^bb3(%6: i16): // 2 preds: ^bb1, ^bb2 + llvm.return %6 : i16 +} + +llvm.func @somefunc(i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.noundef, llvm.signext}) +llvm.func @__gxx_personality_v0(...) -> i32 +llvm.func @__cxa_begin_catch(!llvm.ptr) -> !llvm.ptr +llvm.func @__cxa_end_catch() From 00ecb241f6e47137185b7ea7c5c4ba42dee2df81 Mon Sep 17 00:00:00 2001 From: jeanPerier Date: Fri, 7 Feb 2025 14:13:18 +0100 Subject: [PATCH 4/5] Comment and test indentation update Co-authored-by: Tobias Gysi --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 2 +- mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll | 2 +- mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll | 2 +- mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index bea90acc9364f..0e31804959c22 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1342,7 +1342,7 @@ void CallOp::print(OpAsmPrinter &p) { if (!isDirect) p << getOperand(0).getType() << ", "; - // Reconstruct the function MLIR function type from operand and result types. + // Reconstruct the MLIR function type from operand and result types. call_interface_impl::printFunctionSignature( p, args.getTypes(), getArgAttrsAttr(), /*isVariadic=*/false, getResultTypes(), getResAttrsAttr()); diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll index 2c86ca6b03125..fa39c79bf0859 100644 --- a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll @@ -16,7 +16,7 @@ define void @test_call_arg_attrs_direct(i32 %0, ptr %1) { ; CHECK-SAME: %[[VAL_0:.*]]: i16, ; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) { -; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + ; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) %3 = tail call signext i16 %1(i16 noundef signext %0) ret i16 %3 } diff --git a/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll index e606961e1a252..42489832fd184 100644 --- a/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll @@ -3,7 +3,7 @@ ; CHECK-LABEL: llvm.func @test( ; CHECK-SAME: %[[VAL_0:.*]]: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} { define signext i16 @test(i16 noundef signext %0) personality ptr @__gxx_personality_v0 { -; CHECK: %[[VAL_3:.*]] = llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + ; CHECK: %[[VAL_3:.*]] = llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) %2 = invoke signext i16 @somefunc(i16 noundef signext %0) to label %7 unwind label %3 diff --git a/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir index 5d6e49bfe09e8..ea8ed1d416435 100644 --- a/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir +++ b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir @@ -4,8 +4,8 @@ llvm.func @test(%arg0: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} { %0 = llvm.mlir.zero : !llvm.ptr %1 = llvm.mlir.constant(0 : i16) : i16 -// CHECK: invoke signext i16 @somefunc(i16 noundef signext %{{.*}}) -// CHECK-NEXT: to label %{{.*}} unwind label %{{.*}} + // CHECK: invoke signext i16 @somefunc(i16 noundef signext %{{.*}}) + // CHECK-NEXT: to label %{{.*}} unwind label %{{.*}} %2 = llvm.invoke @somefunc(%arg0) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) ^bb1: // pred: ^bb0 %3 = llvm.landingpad (catch %0 : !llvm.ptr) : !llvm.struct<(ptr, i32)> From e4234d64bfa01abe10f0c01aadbae20eb96ed620 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Mon, 10 Feb 2025 06:36:48 -0800 Subject: [PATCH 5/5] better error reporting and nits --- .../mlir/Target/LLVMIR/ModuleTranslation.h | 3 +- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 15 +++++--- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 7 ++-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 35 ++++++++++++++----- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 3ad7f1e33f0a3..52bdf601e9c7c 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -230,7 +230,8 @@ class ModuleTranslation { /// Translates parameter attributes of a call and adds them to the returned /// AttrBuilder. Returns failure if any of the translations failed. - FailureOr convertParameterAttrs(DictionaryAttr paramAttrs); + FailureOr convertParameterAttrs(CallOpInterface callOp, + DictionaryAttr paramAttrs); /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index fc295ab7e8e1d..ac0511eaaed32 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -229,9 +229,10 @@ convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, LLVM::ModuleTranslation &moduleTranslation) { if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) { for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { - if (auto argAttrs = llvm::cast(argAttrsAttr)) { + if (auto argAttrs = cast(argAttrsAttr); + !argAttrs.empty()) { FailureOr attrBuilder = - moduleTranslation.convertParameterAttrs(argAttrs); + moduleTranslation.convertParameterAttrs(callOp, argAttrs); if (failed(attrBuilder)) return failure(); call->addParamAttrs(argIdx, *attrBuilder); @@ -240,10 +241,14 @@ convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, } ArrayAttr resAttrsArray = callOp.getResAttrsAttr(); - if (resAttrsArray && resAttrsArray.size() == 1) { - if (auto resAttrs = llvm::cast(resAttrsArray[0])) { + if (resAttrsArray && resAttrsArray.size() > 0) { + if (resAttrsArray.size() != 1) + return mlir::emitError(callOp.getLoc(), + "llvm.func cannot have multiple results"); + if (auto resAttrs = cast(resAttrsArray[0]); + !resAttrs.empty()) { FailureOr attrBuilder = - moduleTranslation.convertParameterAttrs(resAttrs); + moduleTranslation.convertParameterAttrs(callOp, resAttrs); if (failed(attrBuilder)) return failure(); call->addRetAttrs(*attrBuilder); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index a55a65f9067ad..f50a0870c0012 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2157,7 +2157,7 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, void ModuleImport::convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, OpBuilder &builder) { - auto llvmAttrs = call->getAttributes(); + llvm::AttributeList llvmAttrs = call->getAttributes(); SmallVector llvmArgAttrsSet; bool anyArgAttrs = false; for (size_t i = 0, e = call->arg_size(); i < e; ++i) { @@ -2181,9 +2181,8 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); if (!llvmResAttr.hasAttributes()) return; - SmallVector resAttrs; - resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder)); - callOp.setResAttrsAttr(getArrayAttr(resAttrs)); + DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder); + callOp.setResAttrsAttr(getArrayAttr({resAttrs})); } template diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 21e15b5dbf96d..48d5d669aa0ae 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1563,23 +1563,33 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func, } } -static void convertParameterAttr(llvm::AttrBuilder &attrBuilder, - llvm::Attribute::AttrKind llvmKind, - NamedAttribute namedAttr, - ModuleTranslation &moduleTranslation) { - llvm::TypeSwitch(namedAttr.getValue()) +static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, + llvm::Attribute::AttrKind llvmKind, + NamedAttribute namedAttr, + ModuleTranslation &moduleTranslation, + Location loc) { + return llvm::TypeSwitch(namedAttr.getValue()) .Case([&](auto typeAttr) { attrBuilder.addTypeAttr( llvmKind, moduleTranslation.convertType(typeAttr.getValue())); + return success(); }) .Case([&](auto intAttr) { attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); + return success(); + }) + .Case([&](auto) { + attrBuilder.addAttribute(llvmKind); + return success(); }) - .Case([&](auto) { attrBuilder.addAttribute(llvmKind); }) .Case([&](auto rangeAttr) { attrBuilder.addConstantRangeAttr( llvmKind, llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper())); + return success(); + }) + .Default([loc](auto) { + return emitError(loc, "unsupported parameter attribute type"); }); } @@ -1588,12 +1598,15 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs) { llvm::AttrBuilder attrBuilder(llvmModule->getContext()); auto attrNameToKindMapping = getAttrNameToKindMapping(); + Location loc = func.getLoc(); for (auto namedAttr : paramAttrs) { auto it = attrNameToKindMapping.find(namedAttr.getName()); if (it != attrNameToKindMapping.end()) { llvm::Attribute::AttrKind llvmKind = it->second; - convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); + if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this, + loc))) + return failure(); } else if (namedAttr.getNameDialect()) { if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this))) return failure(); @@ -1604,15 +1617,19 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, } FailureOr -ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) { +ModuleTranslation::convertParameterAttrs(CallOpInterface callOp, + DictionaryAttr paramAttrs) { llvm::AttrBuilder attrBuilder(llvmModule->getContext()); + Location loc = callOp.getLoc(); auto attrNameToKindMapping = getAttrNameToKindMapping(); for (auto namedAttr : paramAttrs) { auto it = attrNameToKindMapping.find(namedAttr.getName()); if (it != attrNameToKindMapping.end()) { llvm::Attribute::AttrKind llvmKind = it->second; - convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); + if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this, + loc))) + return failure(); } }