Skip to content

Commit b27d97b

Browse files
authored
[mlir][tensor] Add test for invalid tensor.unpack + update error msg (#118275)
Adds a new test for invalid `tensor.unpack` operations where the output rank does not match the expected rank (input rank + num inner tile sizes). For example: ```mlir tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32> ``` In addition, updates the corresponding error message to make it more informative: BEFORE: ```mlir error: packed rank must equal unpacked rank + tiling factors} ``` AFTER: ```mlir error: packed rank != (unpacked rank + num tiling factors), got 3 != 4 ```
1 parent 608f4ae commit b27d97b

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3983,9 +3983,11 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
39833983
: packOrUnPack.getSourceType();
39843984
size_t packedRank = packedType.getRank();
39853985
// Require output rank to match input rank + number of blocking factors.
3986-
if (unpackedRank + mixedTiles.size() != packedRank) {
3986+
size_t expectedPackedRank = unpackedRank + mixedTiles.size();
3987+
if (expectedPackedRank != packedRank) {
39873988
return op->emitError(
3988-
"packed rank must equal unpacked rank + tiling factors");
3989+
"packed rank != (unpacked rank + num tiling factors), got ")
3990+
<< packedRank << " != " << expectedPackedRank;
39893991
}
39903992

39913993
// Verify result shape is greater than the minimum expected

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
22

33
// Asking the dimension of a 0-D shape doesn't make sense.
44
func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
@@ -692,13 +692,21 @@ func.func @pack_invalid_duplicate_element_in_outer_perm(%input: tensor<256x128xf
692692
// -----
693693

694694
func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<64x32x16xf32> {
695-
// expected-error@+1 {{packed rank must equal unpacked rank + tiling factors}}
695+
// expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
696696
%0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<64x32x16xf32>
697697
return %0 : tensor<64x32x16xf32>
698698
}
699699

700700
// -----
701701

702+
func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
703+
// expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
704+
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
705+
return %0 : tensor<256x128xf32>
706+
}
707+
708+
// -----
709+
702710
func.func @unpack_invalid_out_of_bound_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
703711
// expected-error@+1 {{invalid outer_dims_perm vector}}
704712
%0 = tensor.unpack %output outer_dims_perm = [2, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %input : tensor<8x8x32x16xf32> -> tensor<256x128xf32>

0 commit comments

Comments
 (0)