Skip to content

Commit 8023308

Browse files
authored
fix transform tests (#100)
1 parent 960be01 commit 8023308

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

tests/test_transform.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -981,7 +981,8 @@ def main(variant_op: any_op_t()):
981981
entry_point="main", debug_payload_root_tag="payload"
982982
),
983983
)
984-
correct = """\
984+
correct = dedent(
985+
"""\
985986
#map = affine_map<(d0) -> (d0 * 16)>
986987
#map1 = affine_map<(d0) -> (d0 * 64)>
987988
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
@@ -995,27 +996,27 @@ def main(variant_op: any_op_t()):
995996
%1 = tensor.empty() : tensor<1x4x16x64xi8>
996997
%2 = tensor.empty() : tensor<4x1x64x64xi8>
997998
%3 = tensor.empty() : tensor<1x1x16x64xi8>
998-
%4 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %0) -> (tensor<16x256xi8>) {
999-
%5 = affine.apply #map(%arg2)
1000-
%6 = affine.apply #map1(%arg3)
1001-
%extracted_slice = tensor.extract_slice %arg0[%5, 0] [16, 256] [1, 1] : tensor<16x256xi8> to tensor<16x256xi8>
1002-
%extracted_slice_0 = tensor.extract_slice %arg1[0, %6] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
1003-
%extracted_slice_1 = tensor.extract_slice %arg4[%5, %6] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
999+
%4 = linalg.fill ins(%c0_i32 : i32) outs(%3 : tensor<1x1x16x64xi8>) -> tensor<1x1x16x64xi8>
1000+
%5 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %0) -> (tensor<16x256xi8>) {
1001+
%6 = affine.apply #map(%arg2)
1002+
%7 = affine.apply #map1(%arg3)
1003+
%extracted_slice = tensor.extract_slice %arg0[%6, 0] [16, 256] [1, 1] : tensor<16x256xi8> to tensor<16x256xi8>
1004+
%extracted_slice_0 = tensor.extract_slice %arg1[0, %7] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
1005+
%extracted_slice_1 = tensor.extract_slice %arg4[%6, %7] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
10041006
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
10051007
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
1006-
%7 = linalg.fill ins(%c0_i32 : i32) outs(%3 : tensor<1x1x16x64xi8>) -> tensor<1x1x16x64xi8>
1007-
%8 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%7 : tensor<1x1x16x64xi8>) {
1008+
%8 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%4 : tensor<1x1x16x64xi8>) {
10081009
^bb0(%in: i8, %in_3: i8, %out: i8):
10091010
%9 = arith.muli %in, %in_3 : i8
10101011
%10 = arith.addi %out, %9 : i8
10111012
linalg.yield %10 : i8
10121013
} -> tensor<1x1x16x64xi8>
10131014
%unpack = tensor.unpack %8 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_1 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
10141015
scf.forall.in_parallel {
1015-
tensor.parallel_insert_slice %unpack into %arg4[%5, %6] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
1016+
tensor.parallel_insert_slice %unpack into %arg4[%6, %7] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
10161017
}
10171018
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
1018-
return %4 : tensor<16x256xi8>
1019+
return %5 : tensor<16x256xi8>
10191020
}
10201021
}
10211022
module attributes {transform.with_named_sequence} {
@@ -1045,6 +1046,7 @@ def main(variant_op: any_op_t()):
10451046
}
10461047
}
10471048
"""
1049+
)
10481050
filecheck(correct, mod)
10491051

10501052

0 commit comments

Comments
 (0)