Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2107,18 +2107,14 @@ def LLVM_CallIntrinsicOp
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
OptionalAttr<ArrayAttr>:$op_bundle_tags);
OptionalAttr<ArrayAttr>:$op_bundle_tags,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
let results = (outs Optional<LLVM_Type>:$results);
let llvmBuilder = [{
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
}];
let assemblyFormat = [{
$intrin `(` $args `)`
( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
$op_bundle_tags)^ )?
`:` functional-type($args, $results)
attr-dict
}];
let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "StringAttr":$intrin, "ValueRange":$args)>,
Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ class ModuleImport {
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);

/// Converts the parameter and result attributes in `argsAttr` and `resAttr`
/// and add them to the `callOp`.
void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr,
ArrayAttr &resAttr, OpBuilder &builder);

private:
/// Clears the accumulated state before processing a new region.
void clearRegionState() {
Expand Down Expand Up @@ -350,7 +355,8 @@ class ModuleImport {
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
OpBuilder &builder);
/// Converts the parameter and result attributes attached to `call` and adds
/// them to the `callOp`.
/// them to the `callOp`. Implemented in terms of the other definition of
/// the public definition of convertParameterAttributes.
void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this function would be a good use case for a ArgumentAndResultAttributeInterface.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean In the future or as part of this PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this is definitely a separate PR.

OpBuilder &builder);
/// Converts the attributes attached to `inst` and adds them to the `op`.
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class ModuleTranslation {

/// Translates parameter attributes of a call and adds them to the returned
/// AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOpInterface callOp,
FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc,
DictionaryAttr paramAttrs);

/// Gets the named metadata in the LLVM IR module being constructed, creating
Expand Down
108 changes: 104 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3547,30 +3547,130 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::StringAttr intrin, mlir::ValueRange args) {
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
FastmathFlagsAttr{},
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::StringAttr intrin, mlir::ValueRange args,
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
fastMathFlags,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::Type resultType, mlir::StringAttr intrin,
mlir::ValueRange args) {
build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::TypeRange resultTypes,
mlir::StringAttr intrin, mlir::ValueRange args,
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
build(builder, state, resultTypes, intrin, args, fastMathFlags,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr intrinAttr;
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
SmallVector<SmallVector<Type>> opBundleOperandTypes;
ArrayAttr opBundleTags;

// Parse intrinsic name
if (parser.parseCustomAttributeWithFallback(
intrinAttr, parser.getBuilder().getType<mlir::NoneType>())) {
return mlir::failure();
}
result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
intrinAttr);

if (parser.parseLParen())
return mlir::failure();

// Parse the function arguments.
if (parser.parseOperandList(operands))
return mlir::failure();

if (parser.parseRParen())
return mlir::failure();

// Handle bundles.
SMLoc opBundlesLoc = parser.getCurrentLocation();
if (std::optional<ParseResult> result = parseOpBundles(
parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
result && failed(*result))
return failure();
if (opBundleTags && !opBundleTags.empty())
result.addAttribute(
CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
opBundleTags);

SmallVector<DictionaryAttr> argAttrs;
SmallVector<DictionaryAttr> resultAttrs;
if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
operands, argAttrs, resultAttrs))
return failure();
call_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, argAttrs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

// TODO: In CallOp, the attr dict happens *before* the call type.
// CallIntrinsicOp should mimic that, allowing most of this function to be
// shared between the two ops.
if (parser.parseOptionalAttrDict(result.attributes))
return mlir::failure();

if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
getOpBundleSizesAttrName(result.name)))
return failure();

int32_t numOpBundleOperands = 0;
for (const auto &operands : opBundleOperands)
numOpBundleOperands += operands.size();

result.addAttribute(
CallIntrinsicOp::getOperandSegmentSizeAttr(),
parser.getBuilder().getDenseI32ArrayAttr(
{static_cast<int32_t>(operands.size()), numOpBundleOperands}));

return mlir::success();
}

void CallIntrinsicOp::print(OpAsmPrinter &p) {
p << ' ';
p.printAttributeWithoutType(getIntrinAttr());

OperandRange args = getArgs();
p << "(" << args << ")";

// Operand bundles.
if (!getOpBundleOperands().empty()) {
p << ' ';
printOpBundles(p, *this, getOpBundleOperands(),
getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
}
p << " : ";

// Reconstruct the MLIR function type from operand and result types.
call_interface_impl::printFunctionSignature(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yesterday you mentioned something can be bound only once which prevents us from using an assembly format. Was it the args?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, like mentioned in the other comment, perhaps it was due to lack of trying out ref().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parseFunctionSignature in CallInterfaces seems to be missing the operation as second argument (as is the function is probably incompatibly with the custom directive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I tried to do:

  let assemblyFormat = [{
    $intrin `(` $args `)`
    ( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
                        $op_bundle_tags)^ )?
    attr-dict `:`
    custom<IntrinCallArgsAndRet>($_state, ref($args), type($args), type($results))
  }];

But you cannot pass $state (OperationState) because expected variable to refer to an argument, region, result, or successor and that is needed (in the parser case) by parseCallTypeAndResolveOperands.

p, args.getTypes(), getArgAttrsAttr(),
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());

p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{getOperandSegmentSizesAttrName(),
getOpBundleSizesAttrName(), getIntrinAttrName(),
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
getResAttrsAttrName()});
}

//===----------------------------------------------------------------------===//
Expand Down
79 changes: 46 additions & 33 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,46 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}

static LogicalResult
convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
ArrayAttr resAttrsArray, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
if (argAttrsArray) {
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
!argAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(loc, argAttrs);
if (failed(attrBuilder))
return failure();
call->addParamAttrs(argIdx, *attrBuilder);
}
}
}

if (resAttrsArray && resAttrsArray.size() > 0) {
if (resAttrsArray.size() != 1)
return mlir::emitError(loc, "llvm.func cannot have multiple results");
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
!resAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(loc, resAttrs);
if (failed(attrBuilder))
return failure();
call->addRetAttrs(*attrBuilder);
}
}
return success();
}

static LogicalResult
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
return convertParameterAndResultAttrs(
callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
moduleTranslation);
}

/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
Expand Down Expand Up @@ -201,6 +241,12 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
fn, moduleTranslation.lookupValues(op.getArgs()),
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
moduleTranslation));

if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
op.getResAttrsAttr(), inst,
moduleTranslation)))
return failure();

if (op.getNumResults() == 1)
moduleTranslation.mapValue(op->getResults().front()) = inst;
return success();
Expand All @@ -224,39 +270,6 @@ static void convertLinkerOptionsOp(ArrayAttr options,
linkerMDNode->addOperand(listMDNode);
}

static LogicalResult
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
!argAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(callOp, argAttrs);
if (failed(attrBuilder))
return failure();
call->addParamAttrs(argIdx, *attrBuilder);
}
}
}

ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
if (resAttrsArray && resAttrsArray.size() > 0) {
if (resAttrsArray.size() != 1)
return mlir::emitError(callOp.getLoc(),
"llvm.func cannot have multiple results");
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
!resAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(callOp, resAttrs);
if (failed(attrBuilder))
return failure();
call->addRetAttrs(*attrBuilder);
}
}
return success();
}

static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(

moduleImport.setFastmathFlagsAttr(inst, op);

ArrayAttr argsAttr, resAttr;
moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
op.setArgAttrsAttr(argsAttr);
op.setResAttrsAttr(resAttr);

// Update importer tracking of results.
unsigned numRes = op.getNumResults();
if (numRes == 1)
Expand Down
16 changes: 13 additions & 3 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2213,7 +2213,8 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
}

void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
ArrayAttr &argsAttr,
ArrayAttr &resAttr,
OpBuilder &builder) {
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
Expand All @@ -2233,14 +2234,23 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
argsAttr = getArrayAttr(argAttrs);
}

llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
callOp.setResAttrsAttr(getArrayAttr({resAttrs}));
resAttr = getArrayAttr({resAttrs});
}

void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
OpBuilder &builder) {
ArrayAttr argsAttr, resAttr;
convertParameterAttributes(call, argsAttr, resAttr, builder);
callOp.setArgAttrsAttr(argsAttr);
callOp.setResAttrsAttr(resAttr);
}

template <typename Op>
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1696,10 +1696,9 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
}

FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(CallOpInterface callOp,
ModuleTranslation::convertParameterAttrs(Location loc,
DictionaryAttr paramAttrs) {
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
Location loc = callOp.getLoc();
auto attrNameToKindMapping = getAttrNameToKindMapping();

for (auto namedAttr : paramAttrs) {
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/LLVMIR/call-intrin.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,12 @@ llvm.func @bad_args() {
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
llvm.return
}

// -----

// CHECK-LABEL: intrinsic_call_arg_attrs
llvm.func @intrinsic_call_arg_attrs(%arg0: i32) -> i32 {
// CHECK: call i32 @llvm.riscv.sha256sig0(i32 signext %{{.*}})
%0 = llvm.call_intrinsic "llvm.riscv.sha256sig0"(%arg0) : (i32 {llvm.signext}) -> (i32)
llvm.return %0 : i32
}
11 changes: 11 additions & 0 deletions mlir/test/Target/LLVMIR/Import/intrinsic-unregistered.ll
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,14 @@ define void @lround_test(float %0, double %1) {
%3 = call i32 @llvm.lround.i32.f32(float %0)
ret void
}

; // -----

declare i32 @llvm.riscv.sha256sig0(i32)

; CHECK-LABEL: test_intrin_arg_attr
define signext i32 @test_intrin_arg_attr(i32 signext %a) nounwind {
; CHECK: llvm.call_intrinsic "llvm.riscv.sha256sig0"({{.*}}) : (i32 {llvm.signext}) -> i32
%val = call i32 @llvm.riscv.sha256sig0(i32 signext %a)
ret i32 %val
}