Commit d47c553

fix transform tests (#103)
1 parent b2b8d72 commit d47c553

1 file changed: +19 -23


tests/test_transform.py

Lines changed: 19 additions & 23 deletions
@@ -653,11 +653,10 @@ def tile_inner(target):
 %extracted_slice_1 = tensor.extract_slice %arg6[0, %arg3, %1, %2] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<1x3x64x64xf32> to tensor<1x1x8x8xf32>
 %3 = scf.forall (%arg7, %arg8, %arg9) in (1, 8, 8) shared_outs(%arg10 = %extracted_slice_1) -> (tensor<1x1x8x8xf32>) {
   %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, 0, %arg8, %arg9] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x1x10x10xf32> to tensor<1x1x3x3xf32>
-  %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg7, 0, 0, 0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x1x3x3xf32> to tensor<1x1x3x3xf32>
-  %extracted_slice_4 = tensor.extract_slice %arg10[0, %arg7, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x8x8xf32> to tensor<1x1x1x1xf32>
-  %4 = linalg.conv_2d_nchw_fchw ins(%extracted_slice_2, %extracted_slice_3 : tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>) outs(%extracted_slice_4 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+  %extracted_slice_3 = tensor.extract_slice %arg10[0, 0, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x8x8xf32> to tensor<1x1x1x1xf32>
+  %4 = linalg.conv_2d_nchw_fchw ins(%extracted_slice_2, %extracted_slice_0 : tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>) outs(%extracted_slice_3 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
   scf.forall.in_parallel {
-    tensor.parallel_insert_slice %4 into %arg10[0, %arg7, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x1x1xf32> into tensor<1x1x8x8xf32>
+    tensor.parallel_insert_slice %4 into %arg10[0, 0, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x1x1xf32> into tensor<1x1x8x8xf32>
   }
 } {mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>]}
 scf.forall.in_parallel {
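Note on this hunk: both dropped extract_slice ops were indexed by %arg7, a forall dimension of extent 1 that folds to the constant 0, so %extracted_slice_0 is already the full tensor<1x1x3x3xf32> filter and can feed the conv directly. At these sizes, linalg.conv_2d_nchw_fchw reduces to a single multiply-accumulate of the 3x3 input window against the 3x3 filter. A minimal NumPy sketch of that semantics (illustrative names and values, not code from this test):

import numpy as np

# conv_2d_nchw_fchw with input tensor<1x1x3x3xf32>, filter tensor<1x1x3x3xf32>,
# and output tensor<1x1x1x1xf32>: the single output element is the sum of the
# elementwise product of window and filter, accumulated onto the init value
# carried in outs(...).
inp = np.arange(9, dtype=np.float32).reshape(1, 1, 3, 3)  # NCHW input window
flt = np.full((1, 1, 3, 3), 0.5, dtype=np.float32)        # FCHW filter
out = np.zeros((1, 1, 1, 1), dtype=np.float32)            # init from outs(...)
out[0, 0, 0, 0] += np.sum(inp[0, 0] * flt[0, 0])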
@@ -983,37 +982,34 @@ def main(variant_op: any_op_t()):
 )
 correct = dedent(
     """\
-    #map = affine_map<(d0) -> (d0 * 16)>
-    #map1 = affine_map<(d0) -> (d0 * 64)>
-    #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-    #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
-    #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+    #map = affine_map<(d0) -> (d0 * 64)>
+    #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+    #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+    #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
     module {
       module attributes {transform.target_tag = "payload"} {
         func.func @matmul_i8_i8(%arg0: tensor<16x256xi8>, %arg1: tensor<256x256xi8>) -> tensor<16x256xi8> {
           %c0_i32 = arith.constant 0 : i32
           %0 = tensor.empty() : tensor<16x256xi8>
           %1 = tensor.empty() : tensor<1x4x16x64xi8>
+          %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
           %2 = tensor.empty() : tensor<4x1x64x64xi8>
           %3 = tensor.empty() : tensor<1x1x16x64xi8>
           %4 = linalg.fill ins(%c0_i32 : i32) outs(%3 : tensor<1x1x16x64xi8>) -> tensor<1x1x16x64xi8>
           %5 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %0) -> (tensor<16x256xi8>) {
-            %6 = affine.apply #map(%arg2)
-            %7 = affine.apply #map1(%arg3)
-            %extracted_slice = tensor.extract_slice %arg0[%6, 0] [16, 256] [1, 1] : tensor<16x256xi8> to tensor<16x256xi8>
-            %extracted_slice_0 = tensor.extract_slice %arg1[0, %7] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
-            %extracted_slice_1 = tensor.extract_slice %arg4[%6, %7] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
-            %pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
-            %pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
-            %8 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%4 : tensor<1x1x16x64xi8>) {
-            ^bb0(%in: i8, %in_3: i8, %out: i8):
-              %9 = arith.muli %in, %in_3 : i8
-              %10 = arith.addi %out, %9 : i8
-              linalg.yield %10 : i8
+            %6 = affine.apply #map(%arg3)
+            %extracted_slice = tensor.extract_slice %arg1[0, %6] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
+            %extracted_slice_0 = tensor.extract_slice %arg4[0, %6] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
+            %pack_1 = tensor.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
+            %7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_1 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%4 : tensor<1x1x16x64xi8>) {
+            ^bb0(%in: i8, %in_2: i8, %out: i8):
+              %8 = arith.muli %in, %in_2 : i8
+              %9 = arith.addi %out, %8 : i8
+              linalg.yield %9 : i8
             } -> tensor<1x1x16x64xi8>
-            %unpack = tensor.unpack %8 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_1 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
+            %unpack = tensor.unpack %7 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_0 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
             scf.forall.in_parallel {
-              tensor.parallel_insert_slice %unpack into %arg4[%6, %7] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
+              tensor.parallel_insert_slice %unpack into %arg4[0, %6] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
             }
           } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
           return %5 : tensor<16x256xi8>
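Note on this hunk: packing the LHS no longer depends on the loop indices (the old %extracted_slice of %arg0 covered the whole 16x256 tensor), so the %pack of %arg0 is hoisted above the scf.forall and the affine maps renumber accordingly. The remaining linalg.generic is a packed matmul: C[d0, d1, d3, d4] += A[d0, d2, d3, d5] * B[d2, d1, d5, d4], reducing over d2 and d5 per #map1, #map2, #map3. A minimal NumPy sketch of the same contraction (illustrative names; accumulation is widened to int32 and truncated back, mirroring the wrapping i8 muli/addi):

import numpy as np

# Packed operands matching the tensor types in the expected IR:
#   A: tensor<1x4x16x64xi8>, indexed (d0, d2, d3, d5) per #map1
#   B: tensor<4x1x64x64xi8>, indexed (d2, d1, d5, d4) per #map2
#   C: tensor<1x1x16x64xi8>, indexed (d0, d1, d3, d4) per #map3
rng = np.random.default_rng(0)
A = rng.integers(-8, 8, size=(1, 4, 16, 64), dtype=np.int8)
B = rng.integers(-8, 8, size=(4, 1, 64, 64), dtype=np.int8)

# C[a, b, m, n] = sum over c (outer K) and k (inner K) of A[a, c, m, k] * B[c, b, k, n];
# the downcast back to int8 wraps modulo 2^8, like i8 arith.muli/arith.addi.
C = np.einsum("acmk,cbkn->abmn", A.astype(np.int32), B.astype(np.int32)).astype(np.int8)
print(C.shape)  # (1, 1, 16, 64)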
