@@ -320,6 +320,67 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
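+// Fuse a tensor.unpack consumer into the scf.forall producer when the unpacked result size (2047) is not a multiple of the inner tile size (32).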
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
+    }
+    %output = tensor.empty() : tensor<2047xf32>
+    %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+    return %unpack : tensor<2047xf32>
+  }
+}
+
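+// The transform script matches the tensor.parallel_insert_slice inside the loop and fuses the unpack consumer of the loop result via transform.test.fuse_consumer.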
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
+// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
+// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
   func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {