Skip to content

Commit bc4f75b

Browse files
committed
address comments
Signed-off-by: hanhanW <[email protected]>
1 parent 9e82cdf commit bc4f75b

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
106106
result tensor in the order in which they appear, i.e.
107107
`shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`.
108108
- The following relationship for the tiled dimensions holds:
109-
`shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
109+
`shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] ⌈/⌉ inner_tiles[i]`,
110+
(where ⌈/⌉ indicates CeilDiv).
111+
110112

111113
Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
112114
`...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
@@ -153,12 +155,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
153155
- If absent, it is assumed that for all inner tiles,
154156
`shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner
155157
tiles divide perfectly the corresponding outer dimension in the result
156-
tensor.
158+
tensor. It is UB if the tile does not perfectly divide the dimension.
157159
- If present, it will pad along high dimensions (high-padding) to make the
158160
tile complete. Note that it is not allowed to have artificial padding that
159161
is not strictly required by linalg.pack (i.e., padding past what is needed
160162
to complete the last tile along each packed dimension). It is UB if extra
161163
padding is requested.
164+
It is not possible to verify the requirements statically with dynamic
165+
shapes, so violations of the requirements are treated as UB.
162166

163167
Example:
164168
```mlir

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4702,7 +4702,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
47024702
if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
47034703
packedType.getShape()))) {
47044704
return op->emitError("expected ")
4705-
<< expectedPackedType << " for the unpacked domain value, got "
4705+
<< expectedPackedType << " for the packed domain value, got "
47064706
<< packedType;
47074707
}
47084708
return success();

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,7 +1827,7 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
18271827

18281828
func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
18291829
%cst = arith.constant 0.0 : f32
1830-
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
1830+
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
18311831
%0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
18321832
inner_tiles = [8] into %output
18331833
: tensor<9xf32> -> tensor<3x8xf32>
@@ -1839,23 +1839,23 @@ func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3
18391839
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
18401840
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
18411841
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
1842-
// expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the unpacked domain value, got 'tensor<4x16x32x16xf32>'}}
1842+
// expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
18431843
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
18441844
return %0 : tensor<4x16x32x16xf32>
18451845
}
18461846

18471847
// -----
18481848

18491849
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
1850-
// expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the unpacked domain value, got 'tensor<8x7x16x32xf32>'}}
1850+
// expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
18511851
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
18521852
return %0 : tensor<8x7x16x32xf32>
18531853
}
18541854

18551855
// -----
18561856

1857-
func.func @unpack_with_dropping_tiles(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
1858-
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
1857+
func.func @unpack_with_artificial_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
1858+
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
18591859
%0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
18601860
: tensor<3x8xf32> -> tensor<9xf32>
18611861
return %0 : tensor<9xf32>
@@ -1864,7 +1864,7 @@ func.func @unpack_with_dropping_tiles(%input: tensor<3x8xf32>, %output: tensor<9
18641864
// -----
18651865

18661866
func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
1867-
// expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the unpacked domain value, got 'tensor<8x8x4x32xf32>'}}
1867+
// expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
18681868
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
18691869
return %0 : tensor<256x128xf32>
18701870
}

0 commit comments

Comments
 (0)