Skip to content

Commit cd04568

Browse files
committed
Further fixes per @banach-space's comments
1 parent 49557e8 commit cd04568

File tree

6 files changed

+33
-23
lines changed

6 files changed

+33
-23
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,9 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
698698

699699
where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
700700
dimension identifiers (meant to range over valid indices), corresponding to
701-
the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
702-
and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
703-
the dimensions in the set `dims`.
701+
the co-domains of the mandatory (projected permutation) `indexing_maps` of
702+
`A`, `B` and `C`, respectively. `SUM_{dims}` means reduce over all valid
703+
indices for the dimensions in the set `dims`.
704704

705705
The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
706706
domain of each of the `affine_map`s. Like for einsums, the iteration type of
@@ -719,21 +719,24 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
719719
`n` and `b` are of parallel iteration-type) and gets represented as:
720720

721721
```
722-
%0 = linalg.contract
722+
%D = linalg.contract
723723
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
724724
affine_map<(batch, m, n, k) -> (batch, k, n)>,
725725
affine_map<(batch, m, n, k) -> (batch, m, n)>]
726-
ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
727-
outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
726+
ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
727+
outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
728728
```
729729

730730
Note that by permuting the dims in the co-domains of the `affine_map`s, we
731731
can apply arbitrary transposes to the inputs and output. Similarly,
732732
arbitrary broadcasts can be achieved through leaving out dims on either
733-
input operand.
733+
input operand - these dims' inferred iter type will be parallel.
734734

735735
Numeric casting is performed on the operands to the inner multiplication,
736736
promoting them to the same data type as the accumulator/output.
737+
738+
TODO: Allow control over the combining/accumulating op and possibly the
739+
multiplication op.
737740
}];
738741

739742
let arguments = (ins

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3693,10 +3693,18 @@ void ContractOp::print(OpAsmPrinter &p) {
36933693

36943694
LogicalResult ContractOp::verify() {
36953695
int iterationSpaceDims = -1;
3696-
// Maps iter space dim (as index) to num of occurrences in inputs and output.
3696+
// Map iter space dims to #occurrences in inputs' and output's affine_maps:
3697+
// e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3698+
// access an input operand (so occurrence count can be at most 2) and
3699+
// outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
36973700
SmallVector<size_t> inOccurrences;
36983701
SmallVector<size_t> outOccurrences;
36993702

3703+
// For each operand's affine_map and type, check that the rank of the
3704+
// affine_map's domain is the same as those seen prior, check that the
3705+
// affine_map's co-domain rank is the same as that of the corresponding type,
3706+
// check that the affine_map is a projected permutation, and, finally, update
3707+
// inputs and output occurrence counts for dims in the co-domains.
37003708
auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
37013709
bool isInput) -> LogicalResult {
37023710
if (iterationSpaceDims == -1) {
@@ -3737,14 +3745,15 @@ LogicalResult ContractOp::verify() {
37373745
llvm::zip(getIndexingMapsArray(), getOperandTypes(),
37383746
SmallVector<bool>{true, true, false})) {
37393747
if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3740-
return failure(); // NOTE: checking lambda will emit error.
3748+
return failure(); // NB: checkAffineMapAndType will emit relevant error.
37413749
}
37423750

37433751
bool hasContractingDim = false;
37443752
for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
37453753
size_t inOccCount = inOccurrences[dimIndex];
37463754
size_t outOccCount = outOccurrences[dimIndex];
37473755

3756+
// We have a contracting dim if and only if ...
37483757
hasContractingDim |= inOccCount == 2 && outOccCount == 0;
37493758

37503759
if (inOccCount == 0 && outOccCount == 0)

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,13 +1016,13 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
10161016
// CHECK-NEXT: arith.addf
10171017
// CHECK-NEXT: linalg.yield
10181018

1019-
func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
1019+
func.func @contract_matmul(%A: memref<3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
10201020
linalg.contract
10211021
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
10221022
affine_map<(d0, d1, d2) -> (d2, d1)>,
10231023
affine_map<(d0, d1, d2) -> (d0, d1)>]
1024-
ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
1025-
outs(%arg2: memref<3x7xf32>)
1024+
ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
1025+
outs(%C: memref<3x7xf32>)
10261026

10271027
return
10281028
}
@@ -1046,13 +1046,13 @@ func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2
10461046
// CHECK-NEXT: arith.addf
10471047
// CHECK-NEXT: linalg.yield
10481048

1049-
func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
1049+
func.func @contract_matmul_transpose_a_b(%A: memref<5x3xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
10501050
linalg.contract
10511051
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
10521052
affine_map<(d0, d1, d2) -> (d1, d2)>,
10531053
affine_map<(d0, d1, d2) -> (d0, d1)>]
1054-
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
1055-
outs(%arg2: memref<3x7xf32>)
1054+
ins(%A, %B : memref<5x3xf32>, memref<7x5xf32>)
1055+
outs(%C: memref<3x7xf32>)
10561056
return
10571057
}
10581058

@@ -1075,13 +1075,13 @@ func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7
10751075
// CHECK-NEXT: arith.addf
10761076
// CHECK-NEXT: linalg.yield
10771077

1078-
func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
1078+
func.func @contract_batch_matmul(%A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<9x3x7xf32>) {
10791079
linalg.contract
10801080
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
10811081
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
10821082
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
1083-
ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
1084-
outs(%arg2: memref<9x3x7xf32>)
1083+
ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
1084+
outs(%C: memref<9x3x7xf32>)
10851085
return
10861086
}
10871087

mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
121121
// -----
122122

123123
func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
124-
%A : tensor<16x8xf16>,
124+
%A: tensor<16x8xf16>,
125125
%B: tensor<8x32xf64>,
126126
%C: tensor<16x32xf32>) -> tensor<16x32xf32> {
127127
%0 = linalg.contract

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,8 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
549549

550550
func.func @invalid_indexing_maps_placement_contraction(
551551
%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
552-
// expected-error @+2 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
552+
// expected-error @+3 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
553+
// NB: indexing_maps should be provided before ins and outs
553554
linalg.contract
554555
ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
555556
outs(%init : tensor<4x64xf32>)

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,7 +1568,6 @@ func.func @contract_matmul_bcast_b(%A: memref<3x5xf32>, %B: memref<5xf32>, %C: m
15681568

15691569
// -----
15701570

1571-
15721571
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
15731572
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
15741573
// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
@@ -1608,7 +1607,6 @@ func.func @contract_matmul_bcast_a_transpose_b(
16081607
return
16091608
}
16101609

1611-
16121610
// -----
16131611

16141612
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
@@ -1629,7 +1627,6 @@ func.func @contract_matmul_bcast_b_transpose_a(%A: memref<5x3xf32>, %B: memref<5
16291627
return
16301628
}
16311629

1632-
16331630
// -----
16341631

16351632
// CHECK-LABEL: func @mmt4d

0 commit comments

Comments
 (0)