Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions llvm/include/llvm/IR/InstrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
Attrs = Attrs.addRetAttribute(getContext(), Attr);
}

/// Adds attributes to the return value.
void addRetAttrs(const AttrBuilder &B) {
Attrs = Attrs.addRetAttributes(getContext(), B);
}

/// Adds the attribute to the indicated argument
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
assert(ArgNo < arg_size() && "Out of bounds");
Expand All @@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
}

/// Adds attributes to the indicated argument
void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
assert(ArgNo < arg_size() && "Out of bounds");
Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
}

/// removes the attribute from the list of attributes.
void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,18 @@ class ModuleImport {
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter attributes attached to `func` and adds them to
/// the `funcOp`.
/// Converts the parameter and result attributes attached to `func` and adds
/// them to the `funcOp`.
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
OpBuilder &builder);
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
/// DictionaryAttr for the LLVM dialect.
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
OpBuilder &builder);
/// Converts the parameter and result attributes attached to `call` and adds
/// them to the `callOp`.
void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
OpBuilder &builder);
/// Converts the attributes attached to `inst` and adds them to the `op`.
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
/// Converts the attributes attached to `inst` and adds them to the `op`.
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ class ModuleTranslation {
/*recordInsertions=*/false);
}

/// Translates parameter attributes of a call and adds them to the returned
/// AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder> convertParameterAttrs(DictionaryAttr paramAttrs);

/// Gets the named metadata in the LLVM IR module being constructed, creating
/// it if it does not exist.
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
Expand Down Expand Up @@ -346,8 +350,8 @@ class ModuleTranslation {
convertDialectAttributes(Operation *op,
ArrayRef<llvm::Instruction *> instructions);

/// Translates parameter attributes and adds them to the returned AttrBuilder.
/// Returns failure if any of the translations failed.
/// Translates parameter attributes of a function and adds them to the
/// returned AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder>
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);

Expand Down
81 changes: 54 additions & 27 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,55 +1335,66 @@ void CallOp::print(OpAsmPrinter &p) {
getVarCalleeTypeAttrName(), getCConvAttrName(),
getOperandSegmentSizesAttrName(),
getOpBundleSizesAttrName(),
getOpBundleTagsAttrName()});
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
getResAttrsAttrName()});

p << " : ";
if (!isDirect)
p << getOperand(0).getType() << ", ";

// Reconstruct the function MLIR function type from operand and result types.
p.printFunctionalType(args.getTypes(), getResultTypes());
call_interface_impl::printFunctionSignature(
p, args.getTypes(), getArgAttrsAttr(),
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}

/// Parses the type of a call operation and resolves the operands if the parsing
/// succeeds. Returns failure otherwise.
static ParseResult parseCallTypeAndResolveOperands(
OpAsmParser &parser, OperationState &result, bool isDirect,
ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
ArrayRef<OpAsmParser::UnresolvedOperand> operands,
SmallVectorImpl<DictionaryAttr> &argAttrs,
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
SMLoc trailingTypesLoc = parser.getCurrentLocation();
SmallVector<Type> types;
if (parser.parseColonTypeList(types))
if (parser.parseColon())
return failure();

if (isDirect && types.size() != 1)
return parser.emitError(trailingTypesLoc,
"expected direct call to have 1 trailing type");
if (!isDirect && types.size() != 2)
return parser.emitError(trailingTypesLoc,
"expected indirect call to have 2 trailing types");

auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
if (!funcType)
if (!isDirect) {
types.emplace_back();
if (parser.parseType(types.back()))
return failure();
if (parser.parseOptionalComma())
return parser.emitError(
trailingTypesLoc, "expected indirect call to have 2 trailing types");
}
SmallVector<Type> argTypes;
SmallVector<Type> resTypes;
if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
resTypes, resultAttrs)) {
if (isDirect)
return parser.emitError(trailingTypesLoc,
"expected direct call to have 1 trailing types");
return parser.emitError(trailingTypesLoc,
"expected trailing function type");
if (funcType.getNumResults() > 1)
}

if (resTypes.size() > 1)
return parser.emitError(trailingTypesLoc,
"expected function with 0 or 1 result");
if (funcType.getNumResults() == 1 &&
llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
return parser.emitError(trailingTypesLoc,
"expected a non-void result type");

// The head element of the types list matches the callee type for
// indirect calls, while the types list is emtpy for direct calls.
// Append the function input types to resolve the call operation
// operands.
llvm::append_range(types, funcType.getInputs());
llvm::append_range(types, argTypes);
if (parser.resolveOperands(operands, types, parser.getNameLoc(),
result.operands))
return failure();
if (funcType.getNumResults() != 0)
result.addTypes(funcType.getResults());
if (resTypes.size() != 0)
result.addTypes(resTypes);

return success();
}
Expand Down Expand Up @@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();

// Parse the trailing type list and resolve the operands.
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
SmallVector<DictionaryAttr> argAttrs;
SmallVector<DictionaryAttr> resultAttrs;
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
argAttrs, resultAttrs))
return failure();
call_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, argAttrs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
getOpBundleSizesAttrName(result.name)))
Expand Down Expand Up @@ -1643,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
{getCalleeAttrName(), getOperandSegmentSizeAttr(),
getCConvAttrName(), getVarCalleeTypeAttrName(),
getOpBundleSizesAttrName(),
getOpBundleTagsAttrName()});
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
getResAttrsAttrName()});

p << " : ";
if (!isDirect)
p << getOperand(0).getType() << ", ";
p.printFunctionalType(
llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
getResultTypes());
call_interface_impl::printFunctionSignature(
p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
getArgAttrsAttr(),
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}

// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
Expand All @@ -1659,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
// ( `vararg(` var-callee-type `)` )?
// ( `[` op-bundles-list `]` )?
// attribute-dict? `:` (type `,`)? function-type
// attribute-dict? `:` (type `,`)?
// function-type-with-argument-attributes
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
SymbolRefAttr funcAttr;
Expand Down Expand Up @@ -1721,8 +1741,15 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();

// Parse the trailing type list and resolve the function operands.
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
SmallVector<DictionaryAttr> argAttrs;
SmallVector<DictionaryAttr> resultAttrs;
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
argAttrs, resultAttrs))
return failure();
call_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, argAttrs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
getOpBundleSizesAttrName(result.name)))
Expand Down
34 changes: 34 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,34 @@ static void convertLinkerOptionsOp(ArrayAttr options,
linkerMDNode->addOperand(listMDNode);
}

static LogicalResult
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
if (auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr)) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(argAttrs);
if (failed(attrBuilder))
return failure();
call->addParamAttrs(argIdx, *attrBuilder);
}
}
}

ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
if (resAttrsArray && resAttrsArray.size() == 1) {
if (auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0])) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(resAttrs);
if (failed(attrBuilder))
return failure();
call->addRetAttrs(*attrBuilder);
}
}
return success();
}

static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
Expand Down Expand Up @@ -265,6 +293,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getWillReturnAttr())
call->addFnAttr(llvm::Attribute::WillReturn);

if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
return failure();

if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
llvm::MemoryEffects memEffects =
llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
Expand Down Expand Up @@ -372,6 +403,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
if (failed(
convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
return failure();
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
if (invOp->getNumResults() != 0) {
Expand Down
37 changes: 37 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
if (failed(convertCallAttributes(callInst, callOp)))
return failure();
// Handle parameter and result attributes.
convertParameterAttributes(callInst, callOp, builder);
return callOp.getOperation();
}();

Expand Down Expand Up @@ -1786,6 +1788,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();

// Handle parameter and result attributes.
convertParameterAttributes(invokeInst, invokeOp, builder);

if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
else
Expand Down Expand Up @@ -2149,6 +2154,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}

void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
OpBuilder &builder) {
auto llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
if (llvmArgAttrsSet.back().hasAttributes())
anyArgAttrs = true;
}
auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
SmallVector<Attribute> attrs;
for (auto &dict : dictAttrs)
attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
return builder.getArrayAttr(attrs);
};
if (anyArgAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}

llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
SmallVector<DictionaryAttr, 1> resAttrs;
resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder));
callOp.setResAttrsAttr(getArrayAttr(resAttrs));
}

template <typename Op>
static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
Expand Down
Loading