@@ -564,7 +564,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
564564
565565 let summary = [{
566566 Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
567- }];
567+ }];
568568 let description = [{
569569 Numeric casting is performed on the operands to the inner multiply,
570570 promoting them to the same data type as the accumulator/output.
@@ -604,83 +604,83 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
604604 ]
605605 ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
606606 ```
607- }];
608-
609- let arguments = (ins
610- Variadic<AnyType>:$inputs,
611- Variadic<AnyShaped>:$outputs,
612- DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
613- DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
614- );
615- let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
616- let regions = (region AnyRegion:$region);
617-
618- let skipDefaultBuilders = 1;
619- let builders = [
620- OpBuilder<
621- (ins "ValueRange":$inputs, "ValueRange":$outputs,
622- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
623- [{
624- buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
625- attributes, MatmulOp::getRegionBuilder(),
626- MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
627- }]>,
628- OpBuilder<
629- (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
630- "ValueRange":$outputs,
631- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
632- [{
633- buildMatmulOp($_builder, $_state, resultTensorTypes,
634- inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
635- MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
636- }]>,
637- OpBuilder<
638- (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
639- "ValueRange":$outputs,
640- "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
641- [{
642- $_state.addAttribute("cast", cast);
643- buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
644- attributes, MatmulOp::getRegionBuilder(),
645- MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
646- }]>
647-
648- ];
649- let hasCustomAssemblyFormat = 1;
650- let hasFolder = 1;
651- let hasVerifier = 1;
652-
653- let extraClassDeclaration = structuredOpsBaseDecls # [{
654- SmallVector<utils::IteratorType> getIteratorTypesArray();
655-
656- /// Implements the block region builder.
657- static void regionBuilder(ImplicitLocOpBuilder &b,
658- Block &block, ArrayRef<NamedAttribute> attrs);
659-
660- /// Returns a list of AffineMap with the typical matmul indexing charactristic.
661- static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
662-
663- /// Returns true if the given broadcast map \p bcastMap is valid for this op.
664- bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
665-
666- static std::function<void(ImplicitLocOpBuilder &,
667- Block &, ArrayRef<NamedAttribute>)>
668- getRegionBuilder() {
669- return regionBuilder;
670- }
607+ }];
671608
672- ::mlir::MutableOperandRange getDpsInitsMutable() {
673- return getOutputsMutable();
674- }
609+ let arguments = (ins
610+ Variadic<AnyType>:$inputs,
611+ Variadic<AnyShaped>:$outputs,
612+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
613+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
614+ );
615+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
616+ let regions = (region AnyRegion:$region);
617+
618+ let skipDefaultBuilders = 1;
619+ let builders = [
620+ OpBuilder<
621+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
622+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
623+ [{
624+ buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
625+ attributes, MatmulOp::getRegionBuilder(),
626+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
627+ }]>,
628+ OpBuilder<
629+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
630+ "ValueRange":$outputs,
631+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
632+ [{
633+ buildMatmulOp($_builder, $_state, resultTensorTypes,
634+ inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
635+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
636+ }]>,
637+ OpBuilder<
638+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
639+ "ValueRange":$outputs,
640+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
641+ [{
642+ $_state.addAttribute("cast", cast);
643+ buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
644+ attributes, MatmulOp::getRegionBuilder(),
645+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
646+ }]>
647+
648+ ];
649+ let hasCustomAssemblyFormat = 1;
650+ let hasFolder = 1;
651+ let hasVerifier = 1;
652+
653+ let extraClassDeclaration = structuredOpsBaseDecls # [{
654+ SmallVector<utils::IteratorType> getIteratorTypesArray();
655+
656+ /// Implements the block region builder.
657+ static void regionBuilder(ImplicitLocOpBuilder &b,
658+ Block &block, ArrayRef<NamedAttribute> attrs);
659+
660+ /// Returns a list of AffineMap with the typical matmul indexing characteristic.
661+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
662+
663+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
664+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
665+
666+ static std::function<void(ImplicitLocOpBuilder &,
667+ Block &, ArrayRef<NamedAttribute>)>
668+ getRegionBuilder() {
669+ return regionBuilder;
670+ }
671+
672+ ::mlir::MutableOperandRange getDpsInitsMutable() {
673+ return getOutputsMutable();
674+ }
675675
676- // Generic methods.
677- static unsigned getNumRegionArgs();
678- std::string getLibraryCallName();
679- bool hasDynamicIndexingMaps();
680- /// Check if the op has broadcast and/or transpose semantic. Returns true if the
681- /// user defined indexing maps are not equal to default map.
682- bool hasUserDefinedMaps();
683- }];
676+ // Generic methods.
677+ static unsigned getNumRegionArgs();
678+ std::string getLibraryCallName();
679+ bool hasDynamicIndexingMaps();
680+ /// Check if the op has broadcast and/or transpose semantic. Returns true if the
681+ /// user defined indexing maps are not equal to default map.
682+ bool hasUserDefinedMaps();
683+ }];
684684}
685685
686686//===----------------------------------------------------------------------===//
0 commit comments