Skip to content

Commit 072da4b

Browse files
committed
Further doc updates per discussion with @banach-space
1 parent ce14c95 commit 072da4b

File tree

3 files changed

+39
-36
lines changed

3 files changed

+39
-36
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -696,27 +696,28 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
696696

697697
`D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
698698

699-
where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
700-
dimension identifiers (meant to range over valid indices), corresponding to
701-
the co-domains of the mandatory (projected permutation) `indexing_maps` of
702-
`A`, `B` and `C`, respectively. `SUM_{dims}` means reduce over all valid
703-
indices for the dimensions in the set `dims`.
699+
where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
700+
identifiers - meant to range over valid indices - corresponding to the
701+
results of the mandatory (projected permutation) `indexing_maps` for `A`,
702+
`B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
703+
dimensions in the set `dims` (with `I`, `J`, and `K` treated as _sets_ of
704+
dim identifiers).
704705

705706
The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
706707
domain of each of the `affine_map`s. Like for einsums, the iteration type of
707708
each dim is inferred and is either:
708709

709-
- reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
710-
Per the above semantics, these dims will be contracted, i.e. reduced over.
710+
- reduction: the dim is used to index into `A` and `B` but not `C`. Per the
711+
above semantics, these dims will be contracted, i.e. reduced over.
711712

712-
- parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
713-
deriving from matmul terminology - is either an "M-like" dim (if in `A`
714-
and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
715-
`B`, and `C`).
713+
- parallel: the dim is used to index into `C` and at least one of `A` and
714+
`B`, and - deriving from matmul terminology - is either an "M-like" dim
715+
(if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
716+
"batch"-dim (if used to index into `A`, `B`, and `C`).
716717

717718
For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
718719
`H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
719-
`n` and `b` are of parallel iteration-type) and gets represented as:
720+
`n` and `b` have parallel iteration-type) and gets represented as:
720721

721722
```
722723
%D = linalg.contract
@@ -727,12 +728,11 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
727728
outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
728729
```
729730

730-
Note that by permuting dims in the co-domains of the `affine_map`s arbitrary
731-
transposes can be applied to the inputs and output. Similarly, arbitrary
732-
broadcasts can be achieved through leaving out dims on either input operand
733-
(these dims' inferred iter type will be parallel). For example, the
734-
following is a variant of batch-matmul where a transposition is applied to
735-
`A` while matrix `B` gets broadcasted along the batch dimension:
731+
Note that by permuting dims in the `affine_map`s' results, accesses to
732+
to the inputs and output can be arbitrarily transposed. Similarly, arbitrary
733+
broadcasts can be achieved through leaving out dims on either input operand.
734+
For example, the following is a variant of batch-matmul with a transposition
735+
applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim:
736736

737737
```
738738
linalg.contract
@@ -744,7 +744,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
744744
```
745745

746746
Numeric casting is performed on the operands to the inner multiplication,
747-
promoting them to the same data type as the accumulator/output.
747+
promoting/truncating them to the same data type as the accumulator/output.
748748

749749
TODO: Allow control over the combining/accumulating op and possibly the
750750
multiplication op.
@@ -756,6 +756,9 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
756756
AffineMapArrayAttr:$indexing_maps
757757
);
758758
let results = (outs Variadic<AnyShaped>:$result_tensors);
759+
// NB: The only reason this op has a region - and it get populated at op build
760+
// time - is that currently the LinalgOp interface exposes methods that
761+
// assume a relevant region is available to be queried at any time.
759762
let regions = (region SizedRegion<1>:$combiner);
760763

761764
let skipDefaultBuilders = 1;

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3700,33 +3700,33 @@ LogicalResult ContractOp::verify() {
37003700
SmallVector<size_t> inOccurrences;
37013701
SmallVector<size_t> outOccurrences;
37023702

3703-
// For each operand's affine_map and type, check that the rank of the
3704-
// affine_map's domain is the same as those seen prior, check that the
3705-
// affine_map's co-domain rank is the same as that of the corresponding type,
3706-
// check that the affine_map is a projected permutation, and, finally, update
3707-
// inputs and output occurrence counts for dims in the co-domains.
3703+
// A helper so that for each operand's affine_map and type we check that ...
37083704
auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
37093705
bool isInput) -> LogicalResult {
3710-
if (iterationSpaceDims == -1) {
3711-
iterationSpaceDims = affineMap.getNumDims();
3712-
inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3713-
outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3714-
} else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3715-
return emitError("iteration spaces of provided affine_maps differ");
3716-
}
3706+
// ... the affine_map is a projected permutation;
3707+
if (!affineMap.isProjectedPermutation())
3708+
return emitError("provided affine_map is not a projected permutation");
37173709

3710+
// ... the rank of the affine_map's results and corresponding type match;
37183711
if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
37193712
if (affineMap.getNumResults() != shapedType.getRank())
3720-
return emitError("ranks of shaped operand and co-domain of "
3721-
"corresponding affine_map differ");
3713+
return emitError("ranks of shaped operand and results of corresponding "
3714+
"affine_map differ");
37223715
} else if (affineMap.getNumResults() != 0) {
37233716
return emitError("affine_map specifies shaped access while operand has "
37243717
"non-shaped type");
37253718
}
37263719

3727-
if (!affineMap.isProjectedPermutation())
3728-
return emitError("provided affine_map is not a projected permutation");
3720+
// ... the rank of the affine_map's domain is the same as those seen prior;
3721+
if (iterationSpaceDims == -1) {
3722+
iterationSpaceDims = affineMap.getNumDims();
3723+
inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3724+
outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3725+
} else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3726+
return emitError("iteration spaces of provided affine_maps differ");
3727+
}
37293728

3729+
// ... update counts of dims used to access either an input or the output.
37303730
for (AffineExpr affineExpr : affineMap.getResults()) {
37313731
auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
37323732
if (!affineDimExpr)

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ func.func @differing_iteration_space_of_affine_maps_contraction(
592592

593593
func.func @mismatched_ranks_affine_map_and_operand_contraction(
594594
%lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
595-
// expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
595+
// expected-error @+1 {{ranks of shaped operand and results of corresponding affine_map differ}}
596596
linalg.contract
597597
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
598598
affine_map<(d0, d1, d2) -> (d2, d1)>,

0 commit comments

Comments
 (0)