Commit d41bbe7

Address Adam's comments, round 2
1 parent e4df55d commit d41bbe7

4 files changed: +136, -11 lines

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 11 additions & 10 deletions
@@ -3622,15 +3622,15 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
 
 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
   AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
-  /// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
-  /// domains are all the same, and each implements a projected permutation.
-  /// Each dim in the domain must occur for at least one operand and is
-  /// classified as either batch, N-like, M-like, or K-like. Only the latter
-  /// corresponds to a reduction _and_ it is the only dim-kind which does not
-  /// occur for the output operand. We use this fact for fast inference:
+  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+  // domains are all the same, and each implements a projected permutation.
+  // Each iteration space dim must occur for at least one operand and either
+  // takes part in a contraction/reduction or else has parallel iteration type.
+  // We have that a dim is a contraction/reduction dim if and only if the dim
+  // does not occur for the output operand. We use this fact for fast inference:
   // NB: In case we allow dims to occur solely for one input, the above still
   // holds: per the einsum semantics, these are reduction dims as well.
-  auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
   for (auto result : outAffineMap.getResults()) {
     auto dimExpr = dyn_cast<AffineDimExpr>(result);
     assert(dimExpr && "affine_map is a projected permutation");
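
Aside: the comment above describes the fast inference in prose. Below is a minimal standalone C++ sketch of that idea, not the MLIR implementation itself: indexing maps are modeled as plain lists of accessed iteration-space dims, the matmul-shaped example maps are hypothetical, and a dim is inferred parallel iff it occurs in the output map, reduction otherwise.

// Standalone sketch (no MLIR APIs) of the fast iterator-type inference: a
// dim is parallel iff it occurs in the output operand's indexing map, and a
// reduction dim otherwise.
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical matmul-shaped maps: A -> (d0, d2), B -> (d2, d1),
  // C -> (d0, d1). Only the output map C is needed for the inference.
  int numDims = 3;
  std::vector<int> outputMap = {0, 1}; // iteration-space dims accessed by C

  std::vector<bool> dimsInOutput(numDims, false);
  for (int d : outputMap)
    dimsInOutput[d] = true;

  for (int d = 0; d < numDims; ++d)
    std::printf("d%d -> %s\n", d, dimsInOutput[d] ? "parallel" : "reduction");
  // Prints: d0 -> parallel, d1 -> parallel, d2 -> reduction.
  return 0;
}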
@@ -3741,9 +3741,10 @@ LogicalResult ContractOp::verify() {
 
   for (auto &&[affineMap, operandType, isInput] :
        llvm::zip(getIndexingMapsArray(), getOperandTypes(),
-                 SmallVector<bool>{true, true, false}))
+                 SmallVector<bool>{true, true, false})) {
     if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
       return failure(); // NOTE: checking lambda will emit error.
+  }
 
   bool hasContractingDim = false;
   for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
@@ -3752,9 +3753,9 @@ LogicalResult ContractOp::verify() {
 
     hasContractingDim |= inOccCount == 2 && outOccCount == 0;
 
-    if (inOccCount == 0)
+    if (inOccCount == 0 && outOccCount == 0)
       return emitError() << "iteration space dim at index " << dimIndex
-                         << " not used by either input";
+                         << " not used to access any operand";
 
     // NB: We disallow a dim which occurs for only one input operand and not
     // for the output. In terms of einsum semantics such dims have a
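
Similarly, here is a standalone C++ sketch of the per-dim verifier check touched by the two hunks above, again modeling indexing maps as plain lists of dims rather than using the MLIR APIs; the example maps are hypothetical and mirror the 4-dim case from the invalid.mlir test below. A dim accessed by neither the inputs nor the output is rejected, while a dim accessed by both inputs but not the output counts as a contracting dim.

// Standalone sketch (no MLIR APIs) of the per-dim occurrence check in
// ContractOp::verify(): count, per iteration-space dim, how many input
// operands and how many output operands access it.
#include <cstdio>
#include <vector>

int main() {
  int iterationSpaceDims = 4; // d0..d3 (d3 is deliberately unused)
  // Hypothetical maps as lists of accessed dims: A -> (d0, d2), B -> (d2, d1).
  std::vector<std::vector<int>> inputMaps = {{0, 2}, {2, 1}};
  std::vector<int> outputMap = {0, 1}; // C -> (d0, d1)

  std::vector<int> inOcc(iterationSpaceDims, 0), outOcc(iterationSpaceDims, 0);
  for (const auto &map : inputMaps)
    for (int d : map)
      ++inOcc[d];
  for (int d : outputMap)
    ++outOcc[d];

  bool hasContractingDim = false;
  for (int d = 0; d < iterationSpaceDims; ++d) {
    // A contracting dim occurs for both inputs and not for the output.
    hasContractingDim |= inOcc[d] == 2 && outOcc[d] == 0;
    // A dim accessed by no operand at all is an error.
    if (inOcc[d] == 0 && outOcc[d] == 0)
      std::printf("error: iteration space dim at index %d not used to access "
                  "any operand\n",
                  d);
  }
  std::printf("hasContractingDim = %d\n", hasContractingDim); // prints 1
  return 0;
}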

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 25 additions & 0 deletions
@@ -1164,3 +1164,28 @@ func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memre
     outs(%arg2: memref<f32>)
   return
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
+// CHECK-SAME: (%[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>
+    ]
+    ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>)
+    outs(%arg2: memref<3x7xf32>)
+  return
+}

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 1 addition & 1 deletion
@@ -617,7 +617,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
 // -----
 
 func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+1 {{iteration space dim at index 3 not used by either input}}
+  // expected-error @+1 {{iteration space dim at index 3 not used to access any operand}}
   linalg.contract indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
       affine_map<(d0, d1, d2, d3) -> (d2, d1)>,

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 99 additions & 0 deletions
@@ -1528,6 +1528,105 @@ func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: m
 
 // -----
 
+func.func @contract_matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d2, d1)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>
+    ]
+    ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_a
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d2)>,
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>
+    ]
+    ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_b
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>
+    ]
+    ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
+// CHECK: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>
+    ]
+    ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
+// CHECK: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d2, d0)>,
+      affine_map<(d0, d1, d2) -> (d2)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>
+    ]
+    ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_b_transpose_a
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
+// CHECK: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
 // CHECK-LABEL: func @mmt4d
 func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.mmt4d