@@ -570,3 +570,46 @@ module attributes {transform.with_named_sequence} {
570570// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
571571// CHECK: }
572572// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
573+
574+ // -----
575+
576+ module {
577+ func.func @no_fuse_only_dps_consumer (%arg0: tensor <256 x256 xf32 >, %arg1: tensor <256 x256 xf32 >, %arg2: tensor <256 x256 xf32 >) -> (tensor <256 x256 xf32 >, tensor <258 x258 xf32 >) {
578+ %c0 = arith.constant 0 : index
579+ %c64 = arith.constant 64 : index
580+ %c256 = arith.constant 256 : index
581+ %cst = arith.constant 0.000000e+00 : f32
582+ %dest0 = tensor.empty () : tensor <256 x256 xf32 >
583+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args (%arg4 = %dest0 ) -> (tensor <256 x256 xf32 >) {
584+ %extracted_slice_1 = tensor.extract_slice %arg4 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x256 xf32 >
585+ %extracted_slice_2 = tensor.extract_slice %arg0 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x256 xf32 >
586+ %extracted_slice_3 = tensor.extract_slice %arg1 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x256 xf32 >
587+ %3 = linalg.add ins (%extracted_slice_2 , %extracted_slice_3 : tensor <64 x256 xf32 >, tensor <64 x256 xf32 >) outs (%extracted_slice_1 : tensor <64 x256 xf32 >) -> tensor <64 x256 xf32 >
588+ %insert_slice = tensor.insert_slice %3 into %arg4 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <64 x256 xf32 > into tensor <256 x256 xf32 >
589+ scf.yield %insert_slice : tensor <256 x256 xf32 >
590+ }
591+ %dest1 = tensor.empty () : tensor <258 x258 xf32 >
592+ %4 = tensor.insert_slice %1 into %dest1 [0 , 0 ] [256 , 256 ] [1 , 1 ] : tensor <256 x256 xf32 > into tensor <258 x258 xf32 >
593+ %5 = linalg.mul ins (%1 , %arg2 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) outs (%dest0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
594+ return %5 , %4 : tensor <256 x256 xf32 >, tensor <258 x258 xf32 >
595+ }
596+ }
597+
598+ module attributes {transform.with_named_sequence } {
599+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
600+ %slice_ops = transform.structured.match ops {[" tensor.insert_slice" ]} in %arg1
601+ : (!transform.any_op ) -> !transform.any_op
602+ %slice_op , %other_slice = transform.split_handle %slice_ops : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
603+ %a , %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
604+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
605+ transform.yield
606+ }
607+ }
608+ // CHECK: func.func @no_fuse_only_dps_consumer(
609+ // CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
610+ // CHECK: linalg.add
611+ // CHECK: linalg.mul
612+ // CHECK: scf.yield
613+ // CHECK: }
614+ // CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
615+ // CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
0 commit comments