@@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
169
169
}
170
170
171
171
// / Wrapper to return the typical indexing map array attribute for MatmulOp.
172
- static SmallVector<Attribute> getDefaultIndexingMapAttr (MLIRContext *context) {
172
+ static SmallVector<Attribute>
173
+ getDefaultMatmulIndexingMapAttr (MLIRContext *context) {
173
174
return llvm::map_to_vector (
174
175
getDefaultIndexingMapsForMatmul (context),
175
176
[](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
@@ -179,12 +180,11 @@ static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
179
180
// / The result types are derived automatically if `resultTensorTypes` is none.
180
181
// / The body of the operation is filled using `regionBuilder`. All ods-gen
181
182
// / created structured operations use the method to implement their builders.
182
- static void buildStructuredOp (
183
- OpBuilder &b, OperationState &state,
184
- std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
185
- ValueRange outputs, ArrayRef<NamedAttribute> attributes,
186
- RegionBuilderFn regionBuilder,
187
- std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
183
+ static void buildStructuredOp (OpBuilder &b, OperationState &state,
184
+ std::optional<TypeRange> resultTensorTypes,
185
+ ValueRange inputs, ValueRange outputs,
186
+ ArrayRef<NamedAttribute> attributes,
187
+ RegionBuilderFn regionBuilder) {
188
188
// Derive the result types if needed.
189
189
SmallVector<Type> derivedResultTypes =
190
190
resultTensorTypes.value_or (TypeRange ());
@@ -196,6 +196,24 @@ static void buildStructuredOp(
196
196
state.addOperands (outputs);
197
197
state.addTypes (derivedResultTypes);
198
198
199
+ state.addAttributes (attributes);
200
+ state.addAttribute (
201
+ " operandSegmentSizes" ,
202
+ b.getDenseI32ArrayAttr ({static_cast <int32_t >(inputs.size ()),
203
+ static_cast <int32_t >(outputs.size ())}));
204
+
205
+ // Create and fill the region of the structured operation.
206
+ Region ®ion = *state.addRegion ();
207
+ fillStructuredOpRegion (b, region, TypeRange (inputs), TypeRange (outputs),
208
+ state.attributes .getAttrs (), regionBuilder);
209
+ }
210
+
211
+ static void
212
+ buildMatmulOp (OpBuilder &b, OperationState &state,
213
+ std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
214
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes,
215
+ RegionBuilderFn regionBuilder,
216
+ std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
199
217
// Initialize indexingMaps, for MatmulOp.
200
218
SmallVector<Attribute, 3 > indexingMapsAttrVal;
201
219
if (indexingMaps.has_value ()) {
@@ -205,20 +223,11 @@ static void buildStructuredOp(
205
223
}
206
224
state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
207
225
} else {
208
- indexingMapsAttrVal = getDefaultIndexingMapAttr (b.getContext ());
226
+ indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr (b.getContext ());
209
227
state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
210
228
}
211
-
212
- state.addAttributes (attributes);
213
- state.addAttribute (
214
- " operandSegmentSizes" ,
215
- b.getDenseI32ArrayAttr ({static_cast <int32_t >(inputs.size ()),
216
- static_cast <int32_t >(outputs.size ())}));
217
-
218
- // Create and fill the region of the structured operation.
219
- Region ®ion = *state.addRegion ();
220
- fillStructuredOpRegion (b, region, TypeRange (inputs), TypeRange (outputs),
221
- state.attributes .getAttrs (), regionBuilder);
229
+ return buildStructuredOp (b, state, resultTensorTypes, inputs, outputs,
230
+ attributes, regionBuilder);
222
231
}
223
232
224
233
// / Common parsing used for both named structured ops created by ods-gen and by
@@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
340
349
OperationState &result,
341
350
unsigned numRegionArgs,
342
351
RegionBuilderFn regionBuilder) {
343
-
344
- SmallVector<Attribute, 3 > indexingMapsAttr;
345
- Attribute mapAttr;
346
- if (succeeded (parser.parseOptionalKeyword (" indexing_maps" ))) {
347
- if (parser.parseEqual ())
348
- return failure ();
349
-
350
- if (parser.parseLSquare ())
351
- return failure ();
352
-
353
- do {
354
- if (parser.parseAttribute (mapAttr))
355
- return failure ();
356
- if (!isa<AffineMapAttr>(mapAttr)) {
357
- return parser.emitError (parser.getCurrentLocation (),
358
- " expected affine map attribute" );
359
- }
360
- indexingMapsAttr.push_back (mapAttr);
361
-
362
- if (parser.parseOptionalComma ())
363
- break ;
364
- } while (true );
365
-
366
- if (parser.parseRSquare ())
367
- return failure ();
368
- }
369
- // Initialize indexingMaps, if not supplied explicitly.
370
- if (indexingMapsAttr.empty ()) {
371
- indexingMapsAttr = getDefaultIndexingMapAttr (result.getContext ());
372
- }
373
- result.addAttribute (" indexing_maps" ,
374
- parser.getBuilder ().getArrayAttr (indexingMapsAttr));
375
-
376
352
// TODO: Enable when ods-gen supports captures.
377
353
SmallVector<Type, 1 > inputTypes, outputTypes;
378
354
if (parseCommonStructuredOpParts (parser, result, inputTypes, outputTypes))
@@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3503
3479
3504
3480
namespace mlir {
3505
3481
namespace linalg {
3482
+
3506
3483
// ===----------------------------------------------------------------------===//
3507
3484
// MatMulOp
3508
3485
// ===----------------------------------------------------------------------===//
3486
+
3509
3487
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray () {
3510
3488
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3511
3489
utils::IteratorType::parallel,
@@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() {
3520
3498
3521
3499
bool MatmulOp::hasDynamicIndexingMaps () { return true ; }
3522
3500
3523
- // / Check if the op has broadcast and/or transpose semantic. Returns true if the
3524
- // / user defined indexing maps are not equal to default map.
3501
+ // / Check if the op has broadcast and/or transpose semantic. Returns true if
3502
+ // / the user defined indexing maps are not equal to default map.
3525
3503
bool MatmulOp::hasUserDefinedMaps () {
3526
3504
SmallVector<AffineMap, 3 > defaultMaps = getDefaultIndexingMaps ();
3527
3505
SmallVector<AffineMap, 3 > explicitMaps = getIndexingMapsArray ();
@@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3557
3535
helper.yieldOutputs (yields);
3558
3536
}
3559
3537
3560
- // / Returns a list of AffineMap with the typical matmul indexing charactristic.
3538
+ // / Returns a list of AffineMap with the typical matmul indexing
3539
+ // / charactristic.
3561
3540
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps () {
3562
3541
MLIRContext *context = this ->getContext ();
3563
3542
return getDefaultIndexingMapsForMatmul (context);
@@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3572
3551
}
3573
3552
3574
3553
ParseResult MatmulOp::parse (OpAsmParser &parser, OperationState &result) {
3554
+ SmallVector<Attribute, 3 > indexingMapsAttr;
3555
+ Attribute mapAttr;
3556
+ if (succeeded (parser.parseOptionalKeyword (" indexing_maps" ))) {
3557
+ if (parser.parseEqual ())
3558
+ return failure ();
3559
+
3560
+ if (parser.parseLSquare ())
3561
+ return failure ();
3562
+
3563
+ do {
3564
+ if (parser.parseAttribute (mapAttr))
3565
+ return failure ();
3566
+ if (!isa<AffineMapAttr>(mapAttr)) {
3567
+ return parser.emitError (parser.getCurrentLocation (),
3568
+ " expected affine map attribute" );
3569
+ }
3570
+ indexingMapsAttr.push_back (mapAttr);
3571
+
3572
+ if (parser.parseOptionalComma ())
3573
+ break ;
3574
+ } while (true );
3575
+
3576
+ if (parser.parseRSquare ())
3577
+ return failure ();
3578
+ }
3579
+ // Initialize indexingMaps, if not supplied explicitly.
3580
+ if (indexingMapsAttr.empty ()) {
3581
+ indexingMapsAttr = getDefaultMatmulIndexingMapAttr (result.getContext ());
3582
+ }
3583
+ result.addAttribute (" indexing_maps" ,
3584
+ parser.getBuilder ().getArrayAttr (indexingMapsAttr));
3585
+
3575
3586
return parseNamedStructuredOp (parser, result, MatmulOp::getNumRegionArgs (),
3576
3587
MatmulOp::getRegionBuilder ());
3577
3588
}
@@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) {
3582
3593
elidedAttrs);
3583
3594
3584
3595
SmallVector<Attribute, 3 > indexingMaps =
3585
- getDefaultIndexingMapAttr (getContext ());
3596
+ getDefaultMatmulIndexingMapAttr (getContext ());
3586
3597
if (!llvm::equal (getIndexingMaps (), indexingMaps)) {
3587
3598
p << " indexing_maps = [" ;
3588
3599
llvm::interleaveComma (getIndexingMaps (), p,
0 commit comments