Skip to content

Commit e4df55d

Browse files
committed
Address @adam-smnk's comments
1 parent b497f54 commit e4df55d

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
769769

770770
let extraClassDeclaration = structuredOpsBaseDecls # [{
771771
// Declare/implement functions necessary for LinalgStructuredInterface.
772+
772773
/// Infer iterator types for each dim in the domain of IndexingMaps.
773774
SmallVector<utils::IteratorType> getIteratorTypesArray();
774775

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3746,22 +3746,29 @@ LogicalResult ContractOp::verify() {
37463746
return failure(); // NOTE: checking lambda will emit error.
37473747

37483748
bool hasContractingDim = false;
3749-
for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
3749+
for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3750+
size_t inOccCount = inOccurrences[dimIndex];
3751+
size_t outOccCount = outOccurrences[dimIndex];
3752+
37503753
hasContractingDim |= inOccCount == 2 && outOccCount == 0;
37513754

37523755
if (inOccCount == 0)
3753-
return emitError("iteration space dim not used by either input");
3754-
3755-
// NB: A dim which occurs for only one input operand and not for the output.
3756-
// In terms of einsum semantics, such dims have a sensible meaning -
3757-
// namely an additional reduction per such dim - though this can also
3758-
// always be expressed through an additional op. Additionally, at time
3759-
// of writing, vector.contract's verifier accepts these dims but many of
3760-
// its lowerings do not handle these kinds of dims. Hence...
3756+
return emitError() << "iteration space dim at index " << dimIndex
3757+
<< " not used by either input";
3758+
3759+
// NB: We disallow a dim which occurs for only one input operand and not
3760+
// for the output. In terms of einsum semantics such dims have a
3761+
// sensible meaning - namely an additional reduction per each such dim.
3762+
// By contrast, the ContractionOpInterface does not know about this
3763+
// iter type - cf. inferContractionDims' supported dim kinds. Similarly,
3764+
// while vector.contract's verifier accepts dims of this kind many of
3765+
// its lowerings give up on encountering these dims.
37613766
// TODO: Remove following once we have comprehensive support for input-only
37623767
// reduction dims, at both the linalg- and vector-dialect levels.
37633768
if (inOccCount == 1 && outOccCount != 1)
3764-
return emitError("iter type of dim is not one of M, N, K or batch");
3769+
return emitError()
3770+
<< "iteration space dim at index " << dimIndex
3771+
<< " is neither a contracting dim nor of parallel iteration type";
37653772
}
37663773

37673774
if (!hasContractingDim)

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
617617
// -----
618618

619619
func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
620-
// expected-error @+1 {{iteration space dim not used by either input}}
620+
// expected-error @+1 {{iteration space dim at index 3 not used by either input}}
621621
linalg.contract indexing_maps = [
622622
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
623623
affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
@@ -631,7 +631,7 @@ func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: t
631631
// -----
632632

633633
func.func @unused_iteration_space_dim_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
634-
// expected-error @+1 {{iter type of dim is not one of M, N, K or batch}}
634+
// expected-error @+1 {{iteration space dim at index 3 is neither a contracting dim nor of parallel iteration type}}
635635
linalg.contract indexing_maps = [
636636
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
637637
affine_map<(d0, d1, d2, d3) -> (d2, d1)>,

0 commit comments

Comments
 (0)