@@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
169169}
170170
171171// / Wrapper to return the typical indexing map array attribute for MatmulOp.
172- static SmallVector<Attribute> getDefaultIndexingMapAttr (MLIRContext *context) {
172+ static SmallVector<Attribute>
173+ getDefaultMatmulIndexingMapAttr (MLIRContext *context) {
173174 return llvm::map_to_vector (
174175 getDefaultIndexingMapsForMatmul (context),
175176 [](AffineMap map) -> Attribute { return AffineMapAttr::get (map); });
@@ -179,12 +180,11 @@ static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
179180// / The result types are derived automatically if `resultTensorTypes` is none.
180181// / The body of the operation is filled using `regionBuilder`. All ods-gen
181182// / created structured operations use the method to implement their builders.
182- static void buildStructuredOp (
183- OpBuilder &b, OperationState &state,
184- std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
185- ValueRange outputs, ArrayRef<NamedAttribute> attributes,
186- RegionBuilderFn regionBuilder,
187- std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt ) {
183+ static void buildStructuredOp (OpBuilder &b, OperationState &state,
184+ std::optional<TypeRange> resultTensorTypes,
185+ ValueRange inputs, ValueRange outputs,
186+ ArrayRef<NamedAttribute> attributes,
187+ RegionBuilderFn regionBuilder) {
188188 // Derive the result types if needed.
189189 SmallVector<Type> derivedResultTypes =
190190 resultTensorTypes.value_or (TypeRange ());
@@ -196,6 +196,24 @@ static void buildStructuredOp(
196196 state.addOperands (outputs);
197197 state.addTypes (derivedResultTypes);
198198
199+ state.addAttributes (attributes);
200+ state.addAttribute (
201+ " operandSegmentSizes" ,
202+ b.getDenseI32ArrayAttr ({static_cast <int32_t >(inputs.size ()),
203+ static_cast <int32_t >(outputs.size ())}));
204+
205+ // Create and fill the region of the structured operation.
206+ Region ®ion = *state.addRegion ();
207+ fillStructuredOpRegion (b, region, TypeRange (inputs), TypeRange (outputs),
208+ state.attributes .getAttrs (), regionBuilder);
209+ }
210+
211+ static void
212+ buildMatmulOp (OpBuilder &b, OperationState &state,
213+ std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
214+ ValueRange outputs, ArrayRef<NamedAttribute> attributes,
215+ RegionBuilderFn regionBuilder,
216+ std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt ) {
199217 // Initialize indexingMaps, for MatmulOp.
200218 SmallVector<Attribute, 3 > indexingMapsAttrVal;
201219 if (indexingMaps.has_value ()) {
@@ -205,20 +223,11 @@ static void buildStructuredOp(
205223 }
206224 state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
207225 } else {
208- indexingMapsAttrVal = getDefaultIndexingMapAttr (b.getContext ());
226+ indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr (b.getContext ());
209227 state.addAttribute (" indexing_maps" , b.getArrayAttr (indexingMapsAttrVal));
210228 }
211-
212- state.addAttributes (attributes);
213- state.addAttribute (
214- " operandSegmentSizes" ,
215- b.getDenseI32ArrayAttr ({static_cast <int32_t >(inputs.size ()),
216- static_cast <int32_t >(outputs.size ())}));
217-
218- // Create and fill the region of the structured operation.
219- Region ®ion = *state.addRegion ();
220- fillStructuredOpRegion (b, region, TypeRange (inputs), TypeRange (outputs),
221- state.attributes .getAttrs (), regionBuilder);
229+ return buildStructuredOp (b, state, resultTensorTypes, inputs, outputs,
230+ attributes, regionBuilder);
222231}
223232
224233// / Common parsing used for both named structured ops created by ods-gen and by
@@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
340349 OperationState &result,
341350 unsigned numRegionArgs,
342351 RegionBuilderFn regionBuilder) {
343-
344- SmallVector<Attribute, 3 > indexingMapsAttr;
345- Attribute mapAttr;
346- if (succeeded (parser.parseOptionalKeyword (" indexing_maps" ))) {
347- if (parser.parseEqual ())
348- return failure ();
349-
350- if (parser.parseLSquare ())
351- return failure ();
352-
353- do {
354- if (parser.parseAttribute (mapAttr))
355- return failure ();
356- if (!isa<AffineMapAttr>(mapAttr)) {
357- return parser.emitError (parser.getCurrentLocation (),
358- " expected affine map attribute" );
359- }
360- indexingMapsAttr.push_back (mapAttr);
361-
362- if (parser.parseOptionalComma ())
363- break ;
364- } while (true );
365-
366- if (parser.parseRSquare ())
367- return failure ();
368- }
369- // Initialize indexingMaps, if not supplied explicitly.
370- if (indexingMapsAttr.empty ()) {
371- indexingMapsAttr = getDefaultIndexingMapAttr (result.getContext ());
372- }
373- result.addAttribute (" indexing_maps" ,
374- parser.getBuilder ().getArrayAttr (indexingMapsAttr));
375-
376352 // TODO: Enable when ods-gen supports captures.
377353 SmallVector<Type, 1 > inputTypes, outputTypes;
378354 if (parseCommonStructuredOpParts (parser, result, inputTypes, outputTypes))
@@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
35033479
35043480namespace mlir {
35053481namespace linalg {
3482+
35063483// ===----------------------------------------------------------------------===//
35073484// MatMulOp
35083485// ===----------------------------------------------------------------------===//
3486+
35093487SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray () {
35103488 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
35113489 utils::IteratorType::parallel,
@@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() {
35203498
35213499bool MatmulOp::hasDynamicIndexingMaps () { return true ; }
35223500
3523- // / Check if the op has broadcast and/or transpose semantic. Returns true if the
3524- // / user defined indexing maps are not equal to default map.
3501+ // / Check if the op has broadcast and/or transpose semantic. Returns true if
3502+ // / the user defined indexing maps are not equal to default map.
35253503bool MatmulOp::hasUserDefinedMaps () {
35263504 SmallVector<AffineMap, 3 > defaultMaps = getDefaultIndexingMaps ();
35273505 SmallVector<AffineMap, 3 > explicitMaps = getIndexingMapsArray ();
@@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
35573535 helper.yieldOutputs (yields);
35583536}
35593537
3560- // / Returns a list of AffineMap with the typical matmul indexing charactristic.
3538+ // / Returns a list of AffineMap with the typical matmul indexing
3539+ // / charactristic.
35613540SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps () {
35623541 MLIRContext *context = this ->getContext ();
35633542 return getDefaultIndexingMapsForMatmul (context);
@@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
35723551}
35733552
35743553ParseResult MatmulOp::parse (OpAsmParser &parser, OperationState &result) {
3554+ SmallVector<Attribute, 3 > indexingMapsAttr;
3555+ Attribute mapAttr;
3556+ if (succeeded (parser.parseOptionalKeyword (" indexing_maps" ))) {
3557+ if (parser.parseEqual ())
3558+ return failure ();
3559+
3560+ if (parser.parseLSquare ())
3561+ return failure ();
3562+
3563+ do {
3564+ if (parser.parseAttribute (mapAttr))
3565+ return failure ();
3566+ if (!isa<AffineMapAttr>(mapAttr)) {
3567+ return parser.emitError (parser.getCurrentLocation (),
3568+ " expected affine map attribute" );
3569+ }
3570+ indexingMapsAttr.push_back (mapAttr);
3571+
3572+ if (parser.parseOptionalComma ())
3573+ break ;
3574+ } while (true );
3575+
3576+ if (parser.parseRSquare ())
3577+ return failure ();
3578+ }
3579+ // Initialize indexingMaps, if not supplied explicitly.
3580+ if (indexingMapsAttr.empty ()) {
3581+ indexingMapsAttr = getDefaultMatmulIndexingMapAttr (result.getContext ());
3582+ }
3583+ result.addAttribute (" indexing_maps" ,
3584+ parser.getBuilder ().getArrayAttr (indexingMapsAttr));
3585+
35753586 return parseNamedStructuredOp (parser, result, MatmulOp::getNumRegionArgs (),
35763587 MatmulOp::getRegionBuilder ());
35773588}
@@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) {
35823593 elidedAttrs);
35833594
35843595 SmallVector<Attribute, 3 > indexingMaps =
3585- getDefaultIndexingMapAttr (getContext ());
3596+ getDefaultMatmulIndexingMapAttr (getContext ());
35863597 if (!llvm::equal (getIndexingMaps (), indexingMaps)) {
35873598 p << " indexing_maps = [" ;
35883599 llvm::interleaveComma (getIndexingMaps (), p,
0 commit comments