Skip to content

Commit e6598b0

Browse files
committed
Revert "Revert "[mlir][linalg] Replace "string" iterator_types attr with enums in LinalgInterface.""
With python code fixed. This reverts commit 4128090.
1 parent 07665e7 commit e6598b0

File tree

34 files changed

+391
-332
lines changed

34 files changed

+391
-332
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LINALG_BASE
1414
#define LINALG_BASE
1515

16+
include "mlir/Dialect/Utils/StructuredOpsUtils.td"
1617
include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
1718
include "mlir/IR/EnumAttr.td"
1819
include "mlir/IR/OpBase.td"
@@ -71,4 +72,10 @@ def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
7172
let assemblyFormat = "`<` $value `>`";
7273
}
7374

75+
def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
76+
let assemblyFormat = "`<` $value `>`";
77+
}
78+
def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
79+
"Iterator type should be an enum.">;
80+
7481
#endif // LINALG_BASE

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
namespace mlir {
2727
namespace linalg {
28+
class IteratorTypeAttr;
2829
class LinalgOp;
2930

3031
namespace detail {

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
193193
/*args=*/(ins),
194194
/*methodBody=*/"",
195195
/*defaultImplementation=*/[{
196-
return getNumIterators(getParallelIteratorTypeName(),
197-
$_op.getIteratorTypesArray());
196+
return llvm::count($_op.getIteratorTypesArray(),
197+
utils::IteratorType::parallel);
198198
}]
199199
>,
200200
InterfaceMethod<
@@ -207,7 +207,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
207207
/*methodBody=*/"",
208208
/*defaultImplementation=*/[{
209209
return findPositionsOfType($_op.getIteratorTypesArray(),
210-
getParallelIteratorTypeName(), res);
210+
utils::IteratorType::parallel, res);
211211
}]
212212
>,
213213
InterfaceMethod<
@@ -219,8 +219,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
219219
/*args=*/(ins),
220220
/*methodBody=*/"",
221221
/*defaultImplementation=*/[{
222-
return getNumIterators(getReductionIteratorTypeName(),
223-
$_op.getIteratorTypesArray());
222+
return llvm::count($_op.getIteratorTypesArray(),
223+
utils::IteratorType::reduction);
224224
}]
225225
>,
226226
InterfaceMethod<
@@ -233,33 +233,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
233233
/*methodBody=*/"",
234234
/*defaultImplementation=*/[{
235235
return findPositionsOfType($_op.getIteratorTypesArray(),
236-
getReductionIteratorTypeName(), res);
237-
}]
238-
>,
239-
InterfaceMethod<
240-
/*desc=*/[{
241-
Return the number of window loops.
242-
}],
243-
/*retTy=*/"unsigned",
244-
/*methodName=*/"getNumWindowLoops",
245-
/*args=*/(ins),
246-
/*methodBody=*/"",
247-
/*defaultImplementation=*/[{
248-
return getNumIterators(getWindowIteratorTypeName(),
249-
$_op.getIteratorTypesArray());
250-
}]
251-
>,
252-
InterfaceMethod<
253-
/*desc=*/[{
254-
Return the dims that are window loops.
255-
}],
256-
/*retTy=*/"void",
257-
/*methodName=*/"getWindowDims",
258-
/*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
259-
/*methodBody=*/"",
260-
/*defaultImplementation=*/[{
261-
return findPositionsOfType($_op.getIteratorTypesArray(),
262-
getWindowIteratorTypeName(), res);
236+
utils::IteratorType::reduction, res);
263237
}]
264238
>,
265239
InterfaceMethod<
@@ -271,7 +245,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
271245
/*args=*/(ins),
272246
/*methodBody=*/"",
273247
/*defaultImplementation=*/[{
274-
return getNumIterators($_op.getIteratorTypesArray());
248+
return $_op.getIteratorTypesArray().size();
275249
}]
276250
>,
277251
InterfaceMethod<
@@ -286,7 +260,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
286260
/*defaultImplementation=*/[{
287261
auto iters = $_op.getIteratorTypesArray();
288262
return iters.size() == 1 &&
289-
getNumIterators(getReductionIteratorTypeName(), iters) == 1;
263+
llvm::count(iters, utils::IteratorType::reduction) == 1;
290264
}]>,
291265
//===------------------------------------------------------------------===//
292266
// Input and Init arguments handling.
@@ -506,12 +480,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
506480
can be infered from other parameters and in such cases default
507481
getIteratorTypesArray should be overriden.
508482
}],
509-
/*retTy=*/"SmallVector<StringRef>",
483+
/*retTy=*/"SmallVector<utils::IteratorType>",
510484
/*methodName=*/"getIteratorTypesArray",
511485
/*args=*/(ins),
512486
/*methodBody=*/"",
513487
/*defaultImplementation=*/[{
514-
auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
488+
auto range = $_op.getIteratorTypes()
489+
.template getAsValueRange<IteratorTypeAttr,
490+
utils::IteratorType>();
515491
return {range.begin(), range.end()};
516492
}]
517493
>,
@@ -767,10 +743,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
767743
LogicalResult reifyResultShapes(OpBuilder &b,
768744
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
769745

770-
SmallVector<StringRef> getIteratorTypeNames() {
771-
return getIteratorTypesArray();
772-
}
773-
774746
//========================================================================//
775747
// Forwarding functions to access interface methods from the
776748
// DestinationStyleOpInterface.

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
163163
let arguments = (ins Variadic<AnyType>:$inputs,
164164
Variadic<AnyShaped>:$outputs,
165165
AffineMapArrayAttr:$indexing_maps,
166-
ArrayAttr:$iterator_types,
166+
IteratorTypeArrayAttr:$iterator_types,
167167
OptionalAttr<StrAttr>:$doc,
168168
OptionalAttr<StrAttr>:$library_call);
169169
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -178,22 +178,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
178178
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
179179
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
180180
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
181-
"ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
181+
"ArrayRef<utils::IteratorType>":$iteratorTypes, "StringRef":$doc,
182182
"StringRef":$libraryCall,
183183
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
184184
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
185185
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
186-
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
186+
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
187187
"StringRef":$doc, "StringRef":$libraryCall,
188188
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
189189
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
190190
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
191191
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
192-
"ArrayRef<StringRef>":$iteratorTypes,
192+
"ArrayRef<utils::IteratorType>":$iteratorTypes,
193193
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
194194
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
195195
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
196-
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
196+
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
197197
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
198198
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
199199
];
@@ -275,7 +275,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
275275

276276
let extraClassDeclaration = structuredOpsBaseDecls # [{
277277
// Implement functions necessary for LinalgStructuredInterface.
278-
SmallVector<StringRef> getIteratorTypesArray();
278+
SmallVector<utils::IteratorType> getIteratorTypesArray();
279279
ArrayAttr getIndexingMaps();
280280
std::string getLibraryCallName() {
281281
return "op_has_no_registered_library_name";
@@ -356,7 +356,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
356356

357357
let extraClassDeclaration = structuredOpsBaseDecls # [{
358358
// Declare functions necessary for LinalgStructuredInterface.
359-
SmallVector<StringRef> getIteratorTypesArray();
359+
SmallVector<utils::IteratorType> getIteratorTypesArray();
360360
ArrayAttr getIndexingMaps();
361361
std::string getLibraryCallName() {
362362
return "op_has_no_registered_library_name";
@@ -426,7 +426,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
426426

427427
let extraClassDeclaration = structuredOpsBaseDecls # [{
428428
// Declare functions necessary for LinalgStructuredInterface.
429-
SmallVector<StringRef> getIteratorTypesArray();
429+
SmallVector<utils::IteratorType> getIteratorTypesArray();
430430
ArrayAttr getIndexingMaps();
431431
std::string getLibraryCallName() {
432432
return "op_has_no_registered_library_name";
@@ -502,7 +502,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
502502

503503
let extraClassDeclaration = structuredOpsBaseDecls # [{
504504
// Declare functions necessary for LinalgStructuredInterface.
505-
SmallVector<StringRef> getIteratorTypesArray();
505+
SmallVector<utils::IteratorType> getIteratorTypesArray();
506506
ArrayAttr getIndexingMaps();
507507
std::string getLibraryCallName() {
508508
return "op_has_no_registered_library_name";

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ bool hasOnlyScalarElementwiseOp(Region &r);
4242
bool isElementwise(LinalgOp op);
4343

4444
/// Check if iterator type has "parallel" semantics.
45-
bool isParallelIterator(StringRef iteratorType);
45+
bool isParallelIterator(utils::IteratorType iteratorType);
4646

4747
/// Check if iterator type has "reduction" semantics.
48-
bool isReductionIterator(StringRef iteratorType);
48+
bool isReductionIterator(utils::IteratorType iteratorType);
4949

5050
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
5151
/// the type of `source`.
@@ -480,7 +480,8 @@ struct RegionMatcher {
480480
template <typename LoopTy>
481481
struct GenerateLoopNest {
482482
static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
483-
LinalgOp linalgOp, ArrayRef<StringRef> iteratorTypes,
483+
LinalgOp linalgOp,
484+
ArrayRef<utils::IteratorType> iteratorTypes,
484485
function_ref<scf::ValueVector(OpBuilder &, Location,
485486
ValueRange, ValueRange)>
486487
bodyBuilderFn,

mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ namespace mlir {
2222
namespace tosa {
2323

2424
// Creates a SmallVector of Stringrefs for N parallel loops
25-
SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
25+
SmallVector<utils::IteratorType>
26+
getNParallelLoopsAttrs(unsigned nParallelLoops);
2627

2728
// Takes a vector of values and condenses them to a vector with no gaps.
2829
SmallVector<Value> condenseValues(const SmallVector<Value> &values);

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include "mlir/IR/BuiltinAttributes.h"
2222
#include "mlir/IR/Location.h"
2323
#include "mlir/Support/LLVM.h"
24-
#include "llvm/ADT/StringRef.h"
2524

2625
// Pull in all enum type definitions and utility function declarations.
2726
#include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
@@ -48,42 +47,9 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
4847
/// the reduction.
4948
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
5049

51-
/// Use to encode that a particular iterator type has parallel semantics.
52-
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
53-
54-
/// Use to encode that a particular iterator type has reduction semantics.
55-
constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
56-
57-
/// Use to encode that a particular iterator type has window semantics.
58-
constexpr StringRef getWindowIteratorTypeName() { return "window"; }
59-
60-
/// Use to encode that a particular iterator type has window semantics.
61-
inline ArrayRef<StringRef> getAllIteratorTypeNames() {
62-
static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
63-
getReductionIteratorTypeName(),
64-
getWindowIteratorTypeName()};
65-
return llvm::makeArrayRef(names);
66-
}
67-
68-
/// Returns the iterator of a certain type.
69-
inline unsigned getNumIterators(StringRef name,
70-
ArrayRef<StringRef> iteratorTypes) {
71-
auto names = getAllIteratorTypeNames();
72-
(void)names;
73-
assert(llvm::is_contained(names, name));
74-
return llvm::count(iteratorTypes, name);
75-
}
76-
77-
inline unsigned getNumIterators(ArrayRef<StringRef> iteratorTypes) {
78-
unsigned res = 0;
79-
for (auto n : getAllIteratorTypeNames())
80-
res += getNumIterators(n, iteratorTypes);
81-
return res;
82-
}
83-
8450
/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
85-
inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
86-
StringRef iteratorTypeName,
51+
inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
52+
utils::IteratorType iteratorTypeName,
8753
SmallVectorImpl<unsigned> &res) {
8854
for (const auto &en : llvm::enumerate(iteratorTypes)) {
8955
if (en.value() == iteratorTypeName)
@@ -94,29 +60,28 @@ inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
9460
/// Helper StructuredGenerator class to manipulate and rewrite ops with
9561
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
9662
/// yet implement the StructuredOpInterface itself.
97-
template <typename StructuredOpInterface>
63+
template <typename StructuredOpInterface, typename IteratorTypeT>
9864
class StructuredGenerator {
9965
public:
10066
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
10167

10268
struct IteratorType {
103-
IteratorType(StringRef strRef) : strRef(strRef) {}
104-
bool isOfType(StringRef typeName) const { return typeName == strRef; }
105-
StringRef strRef;
69+
IteratorType(IteratorTypeT iter) : iter(iter) {}
70+
bool isOfType(IteratorTypeT expectedIter) const {
71+
return expectedIter == iter;
72+
}
73+
IteratorTypeT iter;
10674
};
10775
struct Par : public IteratorType {
108-
Par() : IteratorType(getParallelIteratorTypeName()) {}
76+
Par() : IteratorType(IteratorTypeT::parallel) {}
10977
};
11078
struct Red : public IteratorType {
111-
Red() : IteratorType(getReductionIteratorTypeName()) {}
112-
};
113-
struct Win : public IteratorType {
114-
Win() : IteratorType(getWindowIteratorTypeName()) {}
79+
Red() : IteratorType(IteratorTypeT::reduction) {}
11580
};
11681

11782
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
11883
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
119-
iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
84+
iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
12085
op(op) {}
12186

12287
bool iters(ArrayRef<IteratorType> its) {
@@ -138,7 +103,7 @@ class StructuredGenerator {
138103
OpBuilder &builder;
139104
MLIRContext *ctx;
140105
Location loc;
141-
SmallVector<StringRef> iterators;
106+
SmallVector<IteratorTypeT> iterators;
142107
SmallVector<AffineMap, 4> maps;
143108
Operation *op;
144109
};

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,11 @@ def Vector_ContractionOp :
269269
return CombiningKind::ADD;
270270
}
271271

272-
// Returns iterator types in string format.
273-
SmallVector<StringRef> getIteratorTypeNames() {
274-
return llvm::to_vector(
275-
llvm::map_range(getIteratorTypes(), [](Attribute a) {
276-
return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
277-
}));
272+
SmallVector<IteratorType> getIteratorTypesArray() {
273+
auto range =
274+
getIteratorTypes()
275+
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
276+
return {range.begin(), range.end()};
278277
}
279278
}];
280279

0 commit comments

Comments
 (0)