@@ -680,6 +680,142 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+    AttrSizedOperandSegments,
+    LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating into the third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
+    identifiers - meant to range over valid indices - corresponding to the
+    results of the mandatory (projected permutation) `indexing_maps` for `A`,
+    `B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
+    dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
+    dim identifiers).
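+
+    For instance, plain (non-batched) matmul corresponds to `I = ⟨ m, k ⟩`,
+    `J = ⟨ k, n ⟩`, `H = ⟨ m, n ⟩` and - as an illustrative sketch, mirroring
+    the batch-matmul example further below - could be written as:
+
+    ```
+    %D = linalg.contract
+        indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>]
+        ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+    ```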
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim is used to index into `A` and `B` but not `C`. Per the
+      above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim is used to index into `C` and at least one of `A` and
+      `B`, and - deriving from matmul terminology - is either an "M-like" dim
+      (if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
+      "batch"-dim (if used to index into `A`, `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` have parallel iteration-type) and gets represented as:
+
+    ```
+    %D = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting dims in the `affine_map`s' results, accesses to the
+    inputs and output can be arbitrarily transposed. Similarly, arbitrary
+    broadcasts can be achieved through leaving out dims on either input operand.
+    For example, the following is a variant of batch-matmul with a transposition
+    applied to `A` while `B`'s 2D-matrix is broadcast along the batch dim:
+
+    ```
+    linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
+                         affine_map<(batch, m, n, k) -> (k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
+        outs(%C: memref<?x?x?xf32>)
+    ```
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    promoting/truncating them to the same data type as the accumulator/output.
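+
+    For instance - as an illustrative variant of the batch-matmul example
+    above - `f16` operands would be promoted to the `f32` accumulator element
+    type before the multiplication and accumulation:
+
+    ```
+    %D = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: tensor<?x?x?xf16>, tensor<?x?x?xf16>)
+        outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```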
+
+    TODO: Allow control over the combining/accumulating op and possibly the
+    multiplication op.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  // NB: The only reason this op has a region - and it gets populated at op
+  // build time - is that currently the LinalgOp interface exposes methods
+  // that assume a relevant region is available to be queried at any time.
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+                   "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+                   CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                          outputs, attributes, regionBuilder);
+      }]>,
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+                   "ArrayAttr":$indexingMaps,
+                   CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, regionBuilder);
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare/implement functions necessary for LinalgStructuredInterface.
+
+    /// Infer iterator types for each dim in the domain of IndexingMaps.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// The indexing maps always depend on the current op instance's attribute.
+    bool hasDynamicIndexingMaps() { return true; }
+    bool hasUserDefinedMaps() { return true; }
+
+    static unsigned getNumRegionArgs();
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement function necessary for DestinationStyleOpInterface.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//