@@ -981,7 +981,8 @@ def main(variant_op: any_op_t()):
981
981
entry_point = "main" , debug_payload_root_tag = "payload"
982
982
),
983
983
)
984
- correct = """\
984
+ correct = dedent (
985
+ """\
985
986
#map = affine_map<(d0) -> (d0 * 16)>
986
987
#map1 = affine_map<(d0) -> (d0 * 64)>
987
988
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
@@ -995,27 +996,27 @@ def main(variant_op: any_op_t()):
995
996
%1 = tensor.empty() : tensor<1x4x16x64xi8>
996
997
%2 = tensor.empty() : tensor<4x1x64x64xi8>
997
998
%3 = tensor.empty() : tensor<1x1x16x64xi8>
998
- %4 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %0) -> (tensor<16x256xi8>) {
999
- %5 = affine.apply #map(%arg2)
1000
- %6 = affine.apply #map1(%arg3)
1001
- %extracted_slice = tensor.extract_slice %arg0[%5, 0] [16, 256] [1, 1] : tensor<16x256xi8> to tensor<16x256xi8>
1002
- %extracted_slice_0 = tensor.extract_slice %arg1[0, %6] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
1003
- %extracted_slice_1 = tensor.extract_slice %arg4[%5, %6] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
999
+ %4 = linalg.fill ins(%c0_i32 : i32) outs(%3 : tensor<1x1x16x64xi8>) -> tensor<1x1x16x64xi8>
1000
+ %5 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %0) -> (tensor<16x256xi8>) {
1001
+ %6 = affine.apply #map(%arg2)
1002
+ %7 = affine.apply #map1(%arg3)
1003
+ %extracted_slice = tensor.extract_slice %arg0[%6, 0] [16, 256] [1, 1] : tensor<16x256xi8> to tensor<16x256xi8>
1004
+ %extracted_slice_0 = tensor.extract_slice %arg1[0, %7] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
1005
+ %extracted_slice_1 = tensor.extract_slice %arg4[%6, %7] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
1004
1006
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
1005
1007
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
1006
- %7 = linalg.fill ins(%c0_i32 : i32) outs(%3 : tensor<1x1x16x64xi8>) -> tensor<1x1x16x64xi8>
1007
- %8 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%7 : tensor<1x1x16x64xi8>) {
1008
+ %8 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%4 : tensor<1x1x16x64xi8>) {
1008
1009
^bb0(%in: i8, %in_3: i8, %out: i8):
1009
1010
%9 = arith.muli %in, %in_3 : i8
1010
1011
%10 = arith.addi %out, %9 : i8
1011
1012
linalg.yield %10 : i8
1012
1013
} -> tensor<1x1x16x64xi8>
1013
1014
%unpack = tensor.unpack %8 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_1 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
1014
1015
scf.forall.in_parallel {
1015
- tensor.parallel_insert_slice %unpack into %arg4[%5 , %6 ] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
1016
+ tensor.parallel_insert_slice %unpack into %arg4[%6 , %7 ] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
1016
1017
}
1017
1018
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
1018
- return %4 : tensor<16x256xi8>
1019
+ return %5 : tensor<16x256xi8>
1019
1020
}
1020
1021
}
1021
1022
module attributes {transform.with_named_sequence} {
@@ -1045,6 +1046,7 @@ def main(variant_op: any_op_t()):
1045
1046
}
1046
1047
}
1047
1048
"""
1049
+ )
1048
1050
filecheck (correct , mod )
1049
1051
1050
1052
0 commit comments