@@ -265,7 +265,7 @@ module {
265265 %c4 = arith.constant 4 : index
266266 %c64 = arith.constant 64 : index
267267 %c0 = arith.constant 0 : index
268- %1 = scf.forall (%arg3 , %arg4 ) in ( 2 , 2 ) shared_outs (%arg5 = %arg2 ) -> (tensor <64 x32 xf32 >) {
268+ %1 = scf.forall (%arg3 , %arg4 ) = ( 0 , 0 ) to ( 64 , 32 ) step ( 32 , 32 ) shared_outs (%arg5 = %arg2 ) -> (tensor <64 x32 xf32 >) {
269269 %extracted_slice = tensor.extract_slice %arg5 [%arg3 , %arg4 ] [32 , 32 ] [1 , 1 ] : tensor <64 x32 xf32 > to tensor <32 x32 xf32 >
270270 %3 = linalg.generic {index ing_maps = [#map , #map , #map ], iterator_types = [" parallel" , " parallel" ]} ins (%arg0 , %arg1 : tensor <32 x32 xf32 >, tensor <32 x32 xf32 >) outs (%extracted_slice : tensor <32 x32 xf32 >) {
271271 ^bb0 (%in: f32 , %in_16: f32 , %out: f32 ):
@@ -292,26 +292,89 @@ module attributes {transform.with_named_sequence} {
292292 transform.yield
293293 }
294294}
295- // CHECK: #[[UNPACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
295+ // CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
296+ // CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
296297// CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
297298// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
298299// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
299300// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
300301// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
301- // CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
302+ // CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
303+ // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
304+ // CHECK-SAME: {
305+ // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
306+ // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
307+ // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
308+ // CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
309+ // CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
310+ // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
311+ // CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
312+ // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
313+ // CHECK-SAME: into %[[TILED_UNPACK_DEST]]
314+ // CHECK: scf.forall.in_parallel {
315+ // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
316+ // CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
317+ // CHECK: }
318+ // CHECK: }
319+ // CHECK: return %[[FINAL_RESULT]]#1 :
320+
321+ // -----
322+
323+ #map = affine_map <(d0 , d1 ) -> (d0 , d1 )>
324+ module {
325+ func.func @fuse_unaligned_unpack_consumer_into_scf_forall (%arg0: tensor <32 x32 xf32 >, %arg1: tensor <32 x32 xf32 >, %arg2: tensor <64 x32 xf32 >) -> tensor <2047 xf32 > {
326+ %c4 = arith.constant 4 : index
327+ %c64 = arith.constant 64 : index
328+ %c0 = arith.constant 0 : index
329+ %1 = scf.forall (%arg3 , %arg4 ) = (0 , 0 ) to (64 , 32 ) step (32 , 32 ) shared_outs (%arg5 = %arg2 ) -> (tensor <64 x32 xf32 >) {
330+ %extracted_slice = tensor.extract_slice %arg5 [%arg3 , %arg4 ] [32 , 32 ] [1 , 1 ] : tensor <64 x32 xf32 > to tensor <32 x32 xf32 >
331+ %3 = linalg.generic {index ing_maps = [#map , #map , #map ], iterator_types = [" parallel" , " parallel" ]} ins (%arg0 , %arg1 : tensor <32 x32 xf32 >, tensor <32 x32 xf32 >) outs (%extracted_slice : tensor <32 x32 xf32 >) {
332+ ^bb0 (%in: f32 , %in_16: f32 , %out: f32 ):
333+ %13 = arith.mulf %in , %in_16 : f32
334+ %14 = arith.addf %out , %13 : f32
335+ linalg.yield %14 : f32
336+ } -> tensor <32 x32 xf32 >
337+ scf.forall.in_parallel {
338+ tensor.parallel_insert_slice %3 into %arg5 [%arg3 , %arg4 ] [32 , 32 ] [1 , 1 ] : tensor <32 x32 xf32 > into tensor <64 x32 xf32 >
339+ }
340+ }
341+ %output = tensor.empty () : tensor <2047 xf32 >
342+ %unpack = tensor.unpack %1 outer_dims_perm = [0 ] inner_dims_pos = [0 ] inner_tiles = [32 ] into %output : tensor <64 x32 xf32 > -> tensor <2047 xf32 >
343+ return %unpack : tensor <2047 xf32 >
344+ }
345+ }
346+
347+ module attributes {transform.with_named_sequence } {
348+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
349+ %slice_op = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg1
350+ : (!transform.any_op ) -> !transform.any_op
351+ %a , %b = transform.test.fuse_consumer %slice_op
352+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
353+ transform.yield
354+ }
355+ }
356+ // CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
357+ // CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
358+ // CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
359+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
360+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
361+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
362+ // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
363+ // CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
302364// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
303365// CHECK-SAME: {
304366// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
305367// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
306368// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
307- // CHECK: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]])
308- // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
369+ // CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
370+ // CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
371+ // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
309372// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
310373// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
311374// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
312375// CHECK: scf.forall.in_parallel {
313376// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
314- // CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024 ] [1]
377+ // CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]] ] [1]
315378// CHECK: }
316379// CHECK: }
317380// CHECK: return %[[FINAL_RESULT]]#1 :
0 commit comments