Skip to content

Commit 8192475

Browse files
committed
[mlir][vector] Allow integer indices in vector.extract/insert ops
`vector.extract` and `vector.insert` can currently take an `i64` constant or an `index` type value as indices. The `index` type will usually lower to an `i32` or `i64` type. However, we are often indexing really small vector dimensions where smaller integers could be used. This PR extends both ops to accept any integer value as indices. For example: ``` %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32> %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32> %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32> ``` This led to some changes to the ops' parser and printer. When a value index is provided, the index type is printed as part of the index list. All the value indices provided must match that type. When no value index is provided, no index type is printed.
1 parent c93e001 commit 8192475

File tree

22 files changed

+353
-163
lines changed

22 files changed

+353
-163
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -695,14 +695,14 @@ def Vector_ExtractOp :
695695
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
696696
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
697697
%3 = vector.extract %1[]: vector<f32> from vector<f32>
698-
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
699-
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
698+
%4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32>
699+
%5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32>
700700
```
701701
}];
702702

703703
let arguments = (ins
704704
AnyVectorOfAnyRank:$vector,
705-
Variadic<Index>:$dynamic_position,
705+
Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
706706
DenseI64ArrayAttr:$static_position
707707
);
708708
let results = (outs AnyType:$result);
@@ -737,7 +737,8 @@ def Vector_ExtractOp :
737737

738738
let assemblyFormat = [{
739739
$vector ``
740-
custom<DynamicIndexList>($dynamic_position, $static_position)
740+
custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
741+
type($dynamic_position))
741742
attr-dict `:` type($result) `from` type($vector)
742743
}];
743744

@@ -883,15 +884,15 @@ def Vector_InsertOp :
883884
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
884885
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
885886
%8 = vector.insert %6, %7[] : f32 into vector<f32>
886-
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
887-
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
887+
%11 = vector.insert %9, %10[%a, %b, %c : index] : vector<f32> into vector<4x8x16xf32>
888+
%12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32>
888889
```
889890
}];
890891

891892
let arguments = (ins
892893
AnyType:$source,
893894
AnyVectorOfAnyRank:$dest,
894-
Variadic<Index>:$dynamic_position,
895+
Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
895896
DenseI64ArrayAttr:$static_position
896897
);
897898
let results = (outs AnyVectorOfAnyRank:$result);
@@ -926,7 +927,9 @@ def Vector_InsertOp :
926927
}];
927928

928929
let assemblyFormat = [{
929-
$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
930+
$source `,` $dest
931+
custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
932+
type($dynamic_position))
930933
attr-dict `:` type($source) `into` type($dest)
931934
}];
932935

@@ -1344,7 +1347,7 @@ def Vector_TransferReadOp :
13441347
%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
13451348
// Update the temporary gathered slice with the individual element
13461349
%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
1347-
%updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
1350+
%updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32>
13481351
memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
13491352
}}}
13501353
// At this point we gathered the elements from the original
@@ -1367,7 +1370,7 @@ def Vector_TransferReadOp :
13671370
%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
13681371
%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
13691372
// Here we only store to the first element in dimension one
1370-
%updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
1373+
%updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32>
13711374
memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
13721375
}}
13731376
// At this point we gathered the elements from the original

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -794,16 +794,26 @@ class AsmParser {
794794
};
795795

796796
/// Parse a list of comma-separated items with an optional delimiter. If a
797-
/// delimiter is provided, then an empty list is allowed. If not, then at
797+
/// delimiter is provided, then an empty list is allowed. If not, then at
798798
/// least one element will be parsed.
799799
///
800+
/// `parseSuffixFn` is an optional function to parse any suffix that can be
801+
/// appended to the comma separated list within the delimiter.
802+
///
800803
/// contextMessage is an optional message appended to "expected '('" sorts of
801804
/// diagnostics when parsing the delimeters.
802-
virtual ParseResult
805+
virtual ParseResult parseCommaSeparatedList(
806+
Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
807+
std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
808+
StringRef contextMessage = StringRef()) = 0;
809+
ParseResult
803810
parseCommaSeparatedList(Delimiter delimiter,
804811
function_ref<ParseResult()> parseElementFn,
805-
StringRef contextMessage = StringRef()) = 0;
806-
812+
StringRef contextMessage) {
813+
return parseCommaSeparatedList(delimiter, parseElementFn,
814+
/*parseSuffixFn=*/std::nullopt,
815+
contextMessage);
816+
}
807817
/// Parse a comma separated list of elements that must have at least one entry
808818
/// in it.
809819
ParseResult
@@ -1319,6 +1329,9 @@ class AsmParser {
13191329
virtual ParseResult
13201330
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
13211331

1332+
/// Parse an optional colon followed by a type.
1333+
virtual ParseResult parseOptionalColonType(Type &result) = 0;
1334+
13221335
/// Parse a keyword followed by a type.
13231336
ParseResult parseKeywordType(const char *keyword, Type &result) {
13241337
return failure(parseKeyword(keyword) || parseType(result));

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
9696
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
9797
/// is non-empty, it is expected to contain as many elements as `values`
9898
/// indicating their types. This allows idiomatic printing of mixed value and
99-
/// integer attributes in a list. E.g.
100-
/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
99+
/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
100+
/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the
101+
/// same and only one type is printed at the end of the list. E.g.,
102+
/// `[0, %arg2, 3, %arg42, 2 : i8]`.
101103
///
102104
/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
103105
/// This notation is similar to how scalable dims are marked when defining
@@ -108,7 +110,8 @@ void printDynamicIndexList(
108110
OpAsmPrinter &printer, Operation *op, OperandRange values,
109111
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
110112
TypeRange valueTypes = TypeRange(),
111-
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
113+
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
114+
bool hasSameTypeDynamicValues = false);
112115
inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
113116
OperandRange values,
114117
ArrayRef<int64_t> integers,
@@ -123,6 +126,13 @@ inline void printDynamicIndexList(
123126
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
124127
delimiter);
125128
}
129+
inline void printSameTypeDynamicIndexList(
130+
OpAsmPrinter &printer, Operation *op, OperandRange values,
131+
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
132+
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
133+
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
134+
delimiter, /*hasSameTypeDynamicValues=*/true);
135+
}
126136

127137
/// Parser hook for custom directive in assemblyFormat.
128138
///
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList(
150160
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
151161
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
152162
SmallVectorImpl<Type> *valueTypes = nullptr,
153-
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
163+
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
164+
bool hasSameTypeDynamicValues = false);
154165
inline ParseResult
155166
parseDynamicIndexList(OpAsmParser &parser,
156167
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList(
188199
return parseDynamicIndexList(parser, values, integers, scalableVals,
189200
&valueTypes, delimiter);
190201
}
202+
inline ParseResult parseSameTypeDynamicIndexList(
203+
OpAsmParser &parser,
204+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
205+
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
206+
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
207+
DenseBoolArrayAttr scalableVals = {};
208+
return parseDynamicIndexList(parser, values, integers, scalableVals,
209+
&valueTypes, delimiter,
210+
/*hasSameTypeDynamicValues=*/true);
211+
}
191212

192213
/// Verify that a the `values` has as many elements as the number of entries in
193214
/// `attr` for which `isDynamic` evaluates to true.

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT {
340340
/// Parse a list of comma-separated items with an optional delimiter. If a
341341
/// delimiter is provided, then an empty list is allowed. If not, then at
342342
/// least one element will be parsed.
343-
ParseResult parseCommaSeparatedList(Delimiter delimiter,
344-
function_ref<ParseResult()> parseElt,
345-
StringRef contextMessage) override {
346-
return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
343+
ParseResult parseCommaSeparatedList(
344+
Delimiter delimiter, function_ref<ParseResult()> parseElt,
345+
std::optional<function_ref<ParseResult()>> parseSuffix,
346+
StringRef contextMessage) override {
347+
return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix,
348+
contextMessage);
347349
}
348350

351+
using BaseT::parseCommaSeparatedList;
352+
349353
//===--------------------------------------------------------------------===//
350354
// Keyword Parsing
351355
//===--------------------------------------------------------------------===//
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT {
590594
return parser.parseTypeListNoParens(result);
591595
}
592596

597+
/// Parse an optional colon followed by a type.
598+
ParseResult parseOptionalColonType(Type &result) override {
599+
SmallVector<Type, 1> types;
600+
ParseResult parseResult = parseOptionalColonTypeList(types);
601+
if (llvm::succeeded(parseResult) && types.size() > 1)
602+
return emitError(getCurrentLocation(), "expected single type");
603+
if (!types.empty())
604+
result = types[0];
605+
return parseResult;
606+
}
607+
593608
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
594609
bool allowDynamic,
595610
bool withTrailingX) override {

mlir/lib/AsmParser/Parser.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
8080
/// Parse a list of comma-separated items with an optional delimiter. If a
8181
/// delimiter is provided, then an empty list is allowed. If not, then at
8282
/// least one element will be parsed.
83-
ParseResult
84-
Parser::parseCommaSeparatedList(Delimiter delimiter,
85-
function_ref<ParseResult()> parseElementFn,
86-
StringRef contextMessage) {
83+
ParseResult Parser::parseCommaSeparatedList(
84+
Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
85+
std::optional<function_ref<ParseResult()>> parseSuffixFn,
86+
StringRef contextMessage) {
8787
switch (delimiter) {
8888
case Delimiter::None:
8989
break;
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter,
144144
return failure();
145145
}
146146

147+
if (parseSuffixFn && (*parseSuffixFn)())
148+
return failure();
149+
147150
switch (delimiter) {
148151
case Delimiter::None:
149152
return success();

mlir/lib/AsmParser/Parser.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,17 @@ class Parser {
4646
/// Parse a list of comma-separated items with an optional delimiter. If a
4747
/// delimiter is provided, then an empty list is allowed. If not, then at
4848
/// least one element will be parsed.
49+
ParseResult parseCommaSeparatedList(
50+
Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
51+
std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
52+
StringRef contextMessage = StringRef());
4953
ParseResult
5054
parseCommaSeparatedList(Delimiter delimiter,
5155
function_ref<ParseResult()> parseElementFn,
52-
StringRef contextMessage = StringRef());
56+
StringRef contextMessage) {
57+
return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt,
58+
contextMessage);
59+
}
5360

5461
/// Parse a comma separated list of elements that must have at least one entry
5562
/// in it.

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering
501501
///
502502
/// Example:
503503
/// ```
504-
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
504+
/// %el = vector.extract %tile[%row, %col : index] : i32 from
505+
/// vector<[4]x[4]xi32>
505506
/// ```
506507
/// Becomes:
507508
/// ```
508509
/// %slice = arm_sme.extract_tile_slice %tile[%row]
509510
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
510-
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
511+
/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32>
511512
/// ```
512513
struct VectorExtractToArmSMELowering
513514
: public OpRewritePattern<vector::ExtractOp> {
@@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering
561562
/// ```
562563
/// %slice = arm_sme.extract_tile_slice %tile[%row]
563564
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
564-
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
565-
/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
565+
/// %new_slice = vector.insert %el, %slice[%col : index] : i32 into
566+
/// vector<[4]xi32> %new_tile = arm_sme.insert_tile_slice %new_slice,
567+
/// %tile[%row]
566568
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
567569
/// ```
568570
struct VectorInsertToArmSMELowering

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
10501050
/// %vscale = vector.vscale
10511051
/// %c4_vscale = arith.muli %vscale, %c4 : index
10521052
/// scf.for %idx = %c0 to %c4_vscale step %c1 {
1053-
/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1054-
/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1055-
/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1056-
/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1053+
/// %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32>
1054+
/// %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32>
1055+
/// %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32>
1056+
/// %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32>
10571057
/// %slice_i = affine.apply #map(%idx)[%i]
10581058
/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
10591059
/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}

mlir/lib/Interfaces/ViewLikeInterface.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
114114
OperandRange values,
115115
ArrayRef<int64_t> integers,
116116
ArrayRef<bool> scalables, TypeRange valueTypes,
117-
AsmParser::Delimiter delimiter) {
117+
AsmParser::Delimiter delimiter,
118+
bool hasSameTypeDynamicValues) {
118119
char leftDelimiter = getLeftDelimiter(delimiter);
119120
char rightDelimiter = getRightDelimiter(delimiter);
120121
printer << leftDelimiter;
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
130131
printer << "[";
131132
if (ShapedType::isDynamic(integer)) {
132133
printer << values[dynamicValIdx];
133-
if (!valueTypes.empty())
134+
if (!hasSameTypeDynamicValues && !valueTypes.empty())
134135
printer << " : " << valueTypes[dynamicValIdx];
135136
++dynamicValIdx;
136137
} else {
@@ -142,14 +143,22 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
142143
scalableIndexIdx++;
143144
});
144145

146+
if (hasSameTypeDynamicValues && !valueTypes.empty()) {
147+
assert(std::all_of(valueTypes.begin(), valueTypes.end(),
148+
[&](Type type) { return type == valueTypes[0]; }) &&
149+
"Expected the same value types");
150+
printer << " : " << valueTypes[0];
151+
}
152+
145153
printer << rightDelimiter;
146154
}
147155

148156
ParseResult mlir::parseDynamicIndexList(
149157
OpAsmParser &parser,
150158
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
151159
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
152-
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
160+
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter,
161+
bool hasSameTypeDynamicValues) {
153162

154163
SmallVector<int64_t, 4> integerVals;
155164
SmallVector<bool, 4> scalableVals;
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList(
163172
if (res.has_value() && succeeded(res.value())) {
164173
values.push_back(operand);
165174
integerVals.push_back(ShapedType::kDynamic);
166-
if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
175+
if (!hasSameTypeDynamicValues && valueTypes &&
176+
parser.parseColonType(valueTypes->emplace_back()))
167177
return failure();
168178
} else {
169179
int64_t integer;
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
178188
return failure();
179189
return success();
180190
};
191+
auto parseColonType = [&]() -> ParseResult {
192+
if (hasSameTypeDynamicValues) {
193+
assert(valueTypes && "Expected non-null value types");
194+
assert(valueTypes->empty() && "Expected no parsed value types");
195+
196+
Type dynValType;
197+
if (parser.parseOptionalColonType(dynValType))
198+
return failure();
199+
200+
if (!dynValType && !values.empty())
201+
return parser.emitError(parser.getNameLoc())
202+
<< "expected a type for dynamic indices";
203+
if (dynValType) {
204+
if (values.empty())
205+
return parser.emitError(parser.getNameLoc())
206+
<< "expected no type for constant indices";
207+
208+
// Broadcast the single type to all the dynamic values.
209+
valueTypes->append(values.size(), dynValType);
210+
}
211+
}
212+
return success();
213+
};
181214
if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
182-
" in dynamic index list"))
215+
parseColonType, " in dynamic index list"))
183216
return parser.emitError(parser.getNameLoc())
184-
<< "expected SSA value or integer";
217+
<< "expected a valid list of SSA values or integers";
218+
185219
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
186220
scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
187221
return success();

0 commit comments

Comments
 (0)