@@ -437,3 +437,74 @@ module attributes {transform.with_named_sequence} {
437437// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
438438// CHECK: }
439439// CHECK: return %[[LOOP_RESULT1]]#1 :
440+
441+ // -----
442+
443+ // This test case checks fusion of consumer even if the producer has multiple uses.
444+ // The multiple uses of the producer essentially means that besides the consumer
445+ // op in concern, the only other uses of the producer are allowed in :-
446+ // 1. scf.yield
447+ // 2. tensor.parallel_insert_slice
448+
449+ module {
450+ module {
451+ func.func @fuse_consumer_for_multi_use_producer (%arg0: tensor <256 x512 xf32 >, %arg1: tensor <512 x256 xf32 >, %arg2: tensor <256 x256 xf32 >) -> (tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) {
452+ %c0 = arith.constant 0 : index
453+ %c64 = arith.constant 64 : index
454+ %c256 = arith.constant 256 : index
455+ %cst = arith.constant 0.000000e+00 : f32
456+ %0 = tensor.empty () : tensor <256 x256 xf32 >
457+ %1 = linalg.fill ins (%cst : f32 ) outs (%0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
458+ %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args (%arg4 = %1 , %arg5 = %arg2 ) -> (tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) {
459+ %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args (%arg7 = %arg4 ) -> (tensor <256 x256 xf32 >) {
460+ %extracted_slice = tensor.extract_slice %arg7 [%arg3 , %arg6 ] [64 , 64 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x64 xf32 >
461+ %extracted_slice_0 = tensor.extract_slice %arg0 [%arg3 , 0 ] [64 , 512 ] [1 , 1 ] : tensor <256 x512 xf32 > to tensor <64 x512 xf32 >
462+ %extracted_slice_1 = tensor.extract_slice %arg1 [0 , %arg6 ] [512 , 64 ] [1 , 1 ] : tensor <512 x256 xf32 > to tensor <512 x64 xf32 >
463+ %5 = linalg.matmul ins (%extracted_slice_0 , %extracted_slice_1 : tensor <64 x512 xf32 >, tensor <512 x64 xf32 >) outs (%extracted_slice : tensor <64 x64 xf32 >) -> tensor <64 x64 xf32 >
464+ %inserted_slice = tensor.insert_slice %5 into %arg7 [%arg3 , %arg6 ] [64 , 64 ] [1 , 1 ] : tensor <64 x64 xf32 > into tensor <256 x256 xf32 >
465+ scf.yield %inserted_slice : tensor <256 x256 xf32 >
466+ }
467+ %4 = linalg.add ins (%3 , %arg5 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) outs (%0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
468+ scf.yield %3 , %4 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >
469+ }
470+ return %2#0 , %2#1 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >
471+ }
472+ }
473+ module attributes {transform.with_named_sequence } {
474+ transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
475+ %0 = transform.structured.match ops {[" tensor.insert_slice" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
476+ %consumer , %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
477+ transform.yield
478+ }
479+ }
480+ }
481+ // CHECK: func.func @fuse_consumer_for_multi_use_producer(
482+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
483+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
484+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
485+ // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
486+ // CHECK: %[[dest1:.*]] = linalg.fill
487+ // CHECK-SAME: outs(%[[dest0]] :
488+ // CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
489+ // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]])
490+ // CHECK-SAME: {
491+ // CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
492+ // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]])
493+ // CHECK-SAME: {
494+ // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
495+ // CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
496+ // CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
497+ // CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
498+ // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
499+ // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
500+ // CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
501+ // CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
502+ // CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
503+ // CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
504+ // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
505+ // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
506+ // CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
507+ // CHECK: }
508+ // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
509+ // CHECK: }
510+ // CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
0 commit comments