Skip to content

Commit b652f02

Browse files
committed
[mlir] share argument attributes interface between calls and callables
1 parent b3924cb commit b652f02

File tree

14 files changed

+359
-234
lines changed

14 files changed

+359
-234
lines changed

mlir/docs/Interfaces.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -753,20 +753,22 @@ interface section goes as follows:
753753
- (`C++ class` -- `ODS class`(if applicable))
754754

755755
##### CallInterfaces
756-
756+
* `OpWithArgumentAttributesInterface` - Used to represent operations that may
757+
carry argument and result attributes. It is inherited by both
758+
CallOpInterface and CallableOpInterface.
759+
- `ArrayAttr getArgAttrsAttr()`
760+
- `ArrayAttr getResAttrsAttr()`
761+
- `void setArgAttrsAttr(ArrayAttr)`
762+
- `void setResAttrsAttr(ArrayAttr)`
763+
- `Attribute removeArgAttrsAttr()`
764+
- `Attribute removeResAttrsAttr()`
757765
* `CallOpInterface` - Used to represent operations like 'call'
758766
- `CallInterfaceCallable getCallableForCallee()`
759767
- `void setCalleeFromCallable(CallInterfaceCallable)`
760768
* `CallableOpInterface` - Used to represent the target callee of call.
761769
- `Region * getCallableRegion()`
762770
- `ArrayRef<Type> getArgumentTypes()`
763771
- `ArrayRef<Type> getResultsTypes()`
764-
- `ArrayAttr getArgAttrsAttr()`
765-
- `ArrayAttr getResAttrsAttr()`
766-
- `void setArgAttrsAttr(ArrayAttr)`
767-
- `void setResAttrsAttr(ArrayAttr)`
768-
- `Attribute removeArgAttrsAttr()`
769-
- `Attribute removeResAttrsAttr()`
770772

771773
##### RegionKindInterfaces
772774

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===- CallImplementation.h - Call and Callable Op utilities ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides utility functions for implementing call-like and
10+
// callable-like operations, in particular, parsing, printing and verification
11+
// components common to these operations.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_INTERFACES_CALLIMPLEMENTATION_H
16+
#define MLIR_INTERFACES_CALLIMPLEMENTATION_H
17+
18+
#include "mlir/IR/OpImplementation.h"
19+
#include "mlir/Interfaces/CallInterfaces.h"
20+
21+
namespace mlir {
22+
23+
class OpWithArgumentAttributesInterface;
24+
25+
namespace call_interface_impl {
26+
27+
/// Parse a function or call result list.
28+
///
29+
/// function-result-list ::= function-result-list-parens
30+
/// | non-function-type
31+
/// function-result-list-parens ::= `(` `)`
32+
/// | `(` function-result-list-no-parens `)`
33+
/// function-result-list-no-parens ::= function-result (`,` function-result)*
34+
/// function-result ::= type attribute-dict?
35+
///
36+
ParseResult
37+
parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
38+
SmallVectorImpl<DictionaryAttr> &resultAttrs);
39+
40+
/// Parses a function signature using `parser`. This does not deal with function
41+
/// signatures containing SSA region arguments (to parse these signatures, use
42+
/// function_interface_impl::parseFunctionSignature). When
43+
/// `mustParseEmptyResult`, `-> ()` is expected when there is no result type.
44+
///
45+
/// no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
46+
/// -> function-result-list
47+
/// no-ssa-function-arg-list ::= no-ssa-function-arg
48+
/// (`,` no-ssa-function-arg)*
49+
/// no-ssa-function-arg ::= type attribute-dict?
50+
ParseResult parseFunctionSignature(OpAsmParser &parser,
51+
SmallVectorImpl<Type> &argTypes,
52+
SmallVectorImpl<DictionaryAttr> &argAttrs,
53+
SmallVectorImpl<Type> &resultTypes,
54+
SmallVectorImpl<DictionaryAttr> &resultAttrs,
55+
bool mustParseEmptyResult = true);
56+
57+
/// Print a function signature for a call or callable operation. If a body
58+
/// region is provided, the SSA arguments are printed in the signature. When
59+
/// `printEmptyResult` is false, `-> function-result-list` is omitted when
60+
/// `resultTypes` is empty.
61+
///
62+
/// function-signature ::= ssa-function-signature
63+
/// | no-ssa-function-signature
64+
/// ssa-function-signature ::= `(` ssa-function-arg-list `)`
65+
/// -> function-result-list
66+
/// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)*
67+
/// ssa-function-arg ::= `%`name `:` type attribute-dict?
68+
void printFunctionSignature(OpAsmPrinter &p,
69+
OpWithArgumentAttributesInterface op,
70+
TypeRange argTypes, bool isVariadic,
71+
TypeRange resultTypes, Region *body = nullptr,
72+
bool printEmptyResult = true);
73+
74+
/// Adds argument and result attributes, provided as `argAttrs` and
75+
/// `resultAttrs` arguments, to the list of operation attributes in `result`.
76+
/// Internally, argument and result attributes are stored as dict attributes
77+
/// with special names given by getResultAttrName, getArgumentAttrName.
78+
void addArgAndResultAttrs(Builder &builder, OperationState &result,
79+
ArrayRef<DictionaryAttr> argAttrs,
80+
ArrayRef<DictionaryAttr> resultAttrs,
81+
StringAttr argAttrsName, StringAttr resAttrsName);
82+
void addArgAndResultAttrs(Builder &builder, OperationState &result,
83+
ArrayRef<OpAsmParser::Argument> args,
84+
ArrayRef<DictionaryAttr> resultAttrs,
85+
StringAttr argAttrsName, StringAttr resAttrsName);
86+
87+
} // namespace call_interface_impl
88+
89+
} // namespace mlir
90+
91+
#endif // MLIR_INTERFACES_CALLIMPLEMENTATION_H

mlir/include/mlir/Interfaces/CallInterfaces.td

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,65 @@
1717

1818
include "mlir/IR/OpBase.td"
1919

20+
21+
/// Interface for operations with arguments attributes (both call-like
22+
/// and callable operations).
23+
def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInterface"> {
24+
let description = [{
25+
A call-like or callable operation that may define attributes for its arguments.
26+
}];
27+
let cppNamespace = "::mlir";
28+
let methods = [
29+
InterfaceMethod<[{
30+
Get the array of argument attribute dictionaries. The method should
31+
return an array attribute containing only dictionary attributes equal in
32+
number to the number of arguments. Alternatively, the method can
33+
return null to indicate that the region has no argument attributes.
34+
}],
35+
"::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
36+
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
37+
InterfaceMethod<[{
38+
Get the array of result attribute dictionaries. The method should return
39+
an array attribute containing only dictionary attributes equal in number
40+
to the number of results. Alternatively, the method can return
41+
null to indicate that the region has no result attributes.
42+
}],
43+
"::mlir::ArrayAttr", "getResAttrsAttr", (ins),
44+
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
45+
InterfaceMethod<[{
46+
Set the array of argument attribute dictionaries.
47+
}],
48+
"void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
49+
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
50+
InterfaceMethod<[{
51+
Set the array of result attribute dictionaries.
52+
}],
53+
"void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
54+
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
55+
InterfaceMethod<[{
56+
Remove the array of argument attribute dictionaries. This is the same as
57+
setting all argument attributes to an empty dictionary. The method should
58+
return the removed attribute.
59+
}],
60+
"::mlir::Attribute", "removeArgAttrsAttr", (ins),
61+
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
62+
InterfaceMethod<[{
63+
Remove the array of result attribute dictionaries. This is the same as
64+
setting all result attributes to an empty dictionary. The method should
65+
return the removed attribute.
66+
}],
67+
"::mlir::Attribute", "removeResAttrsAttr", (ins),
68+
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
69+
];
70+
}
71+
2072
// `CallInterfaceCallable`: This is a type used to represent a single callable
2173
// region. A callable is either a symbol, or an SSA value, that is referenced by
2274
// a call-like operation. This represents the destination of the call.
2375

2476
/// Interface for call-like operations.
25-
def CallOpInterface : OpInterface<"CallOpInterface"> {
77+
def CallOpInterface : OpInterface<"CallOpInterface",
78+
[OpWithArgumentAttributesInterface]> {
2679
let description = [{
2780
A call-like operation is one that transfers control from one sub-routine to
2881
another. These operations may be traditional direct calls `call @foo`, or
@@ -85,7 +138,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
85138
}
86139

87140
/// Interface for callable operations.
88-
def CallableOpInterface : OpInterface<"CallableOpInterface"> {
141+
def CallableOpInterface : OpInterface<"CallableOpInterface",
142+
[OpWithArgumentAttributesInterface]> {
89143
let description = [{
90144
A callable operation is one who represents a potential sub-routine, and may
91145
be a target for a call-like operation (those providing the CallOpInterface
@@ -113,47 +167,6 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
113167
allow for this method may be called on function declarations).
114168
}],
115169
"::llvm::ArrayRef<::mlir::Type>", "getResultTypes">,
116-
117-
InterfaceMethod<[{
118-
Get the array of argument attribute dictionaries. The method should
119-
return an array attribute containing only dictionary attributes equal in
120-
number to the number of region arguments. Alternatively, the method can
121-
return null to indicate that the region has no argument attributes.
122-
}],
123-
"::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
124-
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
125-
InterfaceMethod<[{
126-
Get the array of result attribute dictionaries. The method should return
127-
an array attribute containing only dictionary attributes equal in number
128-
to the number of region results. Alternatively, the method can return
129-
null to indicate that the region has no result attributes.
130-
}],
131-
"::mlir::ArrayAttr", "getResAttrsAttr", (ins),
132-
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
133-
InterfaceMethod<[{
134-
Set the array of argument attribute dictionaries.
135-
}],
136-
"void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
137-
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
138-
InterfaceMethod<[{
139-
Set the array of result attribute dictionaries.
140-
}],
141-
"void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
142-
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
143-
InterfaceMethod<[{
144-
Remove the array of argument attribute dictionaries. This is the same as
145-
setting all argument attributes to an empty dictionary. The method should
146-
return the removed attribute.
147-
}],
148-
"::mlir::Attribute", "removeArgAttrsAttr", (ins),
149-
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
150-
InterfaceMethod<[{
151-
Remove the array of result attribute dictionaries. This is the same as
152-
setting all result attributes to an empty dictionary. The method should
153-
return the removed attribute.
154-
}],
155-
"::mlir::Attribute", "removeResAttrsAttr", (ins),
156-
/*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
157170
];
158171
}
159172

mlir/include/mlir/Interfaces/FunctionImplementation.h

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define MLIR_IR_FUNCTIONIMPLEMENTATION_H_
1717

1818
#include "mlir/IR/OpImplementation.h"
19+
#include "mlir/Interfaces/CallImplementation.h"
1920
#include "mlir/Interfaces/FunctionInterfaces.h"
2021

2122
namespace mlir {
@@ -33,19 +34,6 @@ class VariadicFlag {
3334
bool variadic;
3435
};
3536

36-
/// Adds argument and result attributes, provided as `argAttrs` and
37-
/// `resultAttrs` arguments, to the list of operation attributes in `result`.
38-
/// Internally, argument and result attributes are stored as dict attributes
39-
/// with special names given by getResultAttrName, getArgumentAttrName.
40-
void addArgAndResultAttrs(Builder &builder, OperationState &result,
41-
ArrayRef<DictionaryAttr> argAttrs,
42-
ArrayRef<DictionaryAttr> resultAttrs,
43-
StringAttr argAttrsName, StringAttr resAttrsName);
44-
void addArgAndResultAttrs(Builder &builder, OperationState &result,
45-
ArrayRef<OpAsmParser::Argument> args,
46-
ArrayRef<DictionaryAttr> resultAttrs,
47-
StringAttr argAttrsName, StringAttr resAttrsName);
48-
4937
/// Callback type for `parseFunctionOp`, the callback should produce the
5038
/// type that will be associated with a function-like operation from lists of
5139
/// function arguments and results, VariadicFlag indicates whether the function
@@ -84,9 +72,13 @@ void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
8472

8573
/// Prints the signature of the function-like operation `op`. Assumes `op` has
8674
/// is a FunctionOpInterface and has passed verification.
87-
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
88-
ArrayRef<Type> argTypes, bool isVariadic,
89-
ArrayRef<Type> resultTypes);
75+
inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
76+
ArrayRef<Type> argTypes, bool isVariadic,
77+
ArrayRef<Type> resultTypes) {
78+
call_interface_impl::printFunctionSignature(p, op, argTypes, isVariadic,
79+
resultTypes, &op->getRegion(0),
80+
/*printEmptyResult=*/false);
81+
}
9082

9183
/// Prints the list of function prefixed with the "attributes" keyword. The
9284
/// attributes with names listed in "elided" as well as those used by the

mlir/lib/Dialect/Async/IR/Async.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
308308
if (argAttrs.empty())
309309
return;
310310
assert(type.getNumInputs() == argAttrs.size());
311-
function_interface_impl::addArgAndResultAttrs(
311+
call_interface_impl::addArgAndResultAttrs(
312312
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
313313
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
314314
}

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
529529
if (argAttrs.empty())
530530
return;
531531
assert(type.getNumInputs() == argAttrs.size());
532-
function_interface_impl::addArgAndResultAttrs(
532+
call_interface_impl::addArgAndResultAttrs(
533533
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
534534
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
535535
}

mlir/lib/Dialect/Func/IR/FuncOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
190190
if (argAttrs.empty())
191191
return;
192192
assert(type.getNumInputs() == argAttrs.size());
193-
function_interface_impl::addArgAndResultAttrs(
193+
call_interface_impl::addArgAndResultAttrs(
194194
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
195195
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
196196
}

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1487,7 +1487,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
14871487
result.addAttribute(getFunctionTypeAttrName(result.name),
14881488
TypeAttr::get(type));
14891489

1490-
function_interface_impl::addArgAndResultAttrs(
1490+
call_interface_impl::addArgAndResultAttrs(
14911491
builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
14921492
getResAttrsAttrName(result.name));
14931493

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,7 +2510,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
25102510

25112511
assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
25122512
"expected as many argument attribute lists as arguments");
2513-
function_interface_impl::addArgAndResultAttrs(
2513+
call_interface_impl::addArgAndResultAttrs(
25142514
builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
25152515
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
25162516
}
@@ -2636,7 +2636,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
26362636

26372637
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
26382638
return failure();
2639-
function_interface_impl::addArgAndResultAttrs(
2639+
call_interface_impl::addArgAndResultAttrs(
26402640
parser.getBuilder(), result, entryArgs, resultAttrs,
26412641
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
26422642

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
940940

941941
// Add the attributes to the function arguments.
942942
assert(resultAttrs.size() == resultTypes.size());
943-
function_interface_impl::addArgAndResultAttrs(
943+
call_interface_impl::addArgAndResultAttrs(
944944
builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
945945
getResAttrsAttrName(result.name));
946946

0 commit comments

Comments
 (0)