Skip to content

Commit 2ff8bb6

Browse files
qedawkinsGroverkss
authored andcommitted
[mlir][SCF] Fix condition for fusability in consumer fusion API (llvm#115768)
It was previously allowing either a tilable or dps op to be fused. Both are required for consumer fusion.
1 parent 9909dfa commit 2ff8bb6

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,46 @@ module attributes {transform.with_named_sequence} {
570570
// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
571571
// CHECK: }
572572
// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
573+
574+
// -----
575+
576+
module {
577+
func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
578+
%c0 = arith.constant 0 : index
579+
%c64 = arith.constant 64 : index
580+
%c256 = arith.constant 256 : index
581+
%cst = arith.constant 0.000000e+00 : f32
582+
%dest0 = tensor.empty() : tensor<256x256xf32>
583+
%1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
584+
%extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
585+
%extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
586+
%extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
587+
%3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
588+
%insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
589+
scf.yield %insert_slice : tensor<256x256xf32>
590+
}
591+
%dest1 = tensor.empty() : tensor<258x258xf32>
592+
%4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
593+
%5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
594+
return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
595+
}
596+
}
597+
598+
module attributes {transform.with_named_sequence} {
599+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
600+
%slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
601+
: (!transform.any_op) -> !transform.any_op
602+
%slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
603+
%a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
604+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
605+
transform.yield
606+
}
607+
}
608+
// CHECK: func.func @no_fuse_only_dps_consumer(
609+
// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
610+
// CHECK: linalg.add
611+
// CHECK: linalg.mul
612+
// CHECK: scf.yield
613+
// CHECK: }
614+
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
615+
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]

0 commit comments

Comments
 (0)