
Commit 99304ff

[Flow] Fix cloning of flow.tensor.transfer into dispatch (#19838)
Fixes a case where `flow.tensor.transfer` was getting cloned into dispatches containing an attention op, by no longer cloning any `Flow` dialect ops. This happened because `isAttentionMaskGenerator` returns true for all ops that are used by attention ops (including Flow ops).

Signed-off-by: Ian Wood <[email protected]>
1 parent: 06eaead
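For orientation, below is a minimal C++ sketch of the interaction the commit message describes. It is not the actual IREE source: the include paths, the *Sketch function names, and the simplified heuristic body are assumptions drawn from the commit message and the diffs that follow; only the `isa<Flow::FlowDialect>` early exit is taken directly from the change.

// Minimal sketch of the described behavior; names and include paths are
// assumptions, not the real IREE implementation.
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/IR/Operation.h"

namespace mlir::iree_compiler {

// Per the commit message, the mask heuristic treats any op used by an
// attention op as a mask generator, so a flow.tensor.transfer feeding the
// mask operand was previously considered clonable.
static bool isAttentionMaskGeneratorSketch(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (isa<IREE::LinalgExt::AttentionOp>(user))
      return true;
  }
  return false;
}

// Trimmed-down view of isClonableIntoDispatchOp with the new early exit:
// any op owned by the Flow dialect bails out before the heuristic above
// can mark it as clonable.
static bool isClonableIntoDispatchOpSketch(Operation *op) {
  if (isa<IREE::Flow::FlowDialect>(op->getDialect()))
    return false;
  if (isAttentionMaskGeneratorSketch(op))
    return true;
  // ... remaining clonability checks elided ...
  return false;
}

} // namespace mlir::iree_compiler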

2 files changed: +45 -0


compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp

Lines changed: 4 additions & 0 deletions
@@ -807,6 +807,10 @@ static bool isAttentionMaskGenerator(Operation *op) {
 /// operations as roots.
 bool isClonableIntoDispatchOp(Operation *op,
                               ClonableIntoDispatchOptions options) {
+  if (isa<Flow::FlowDialect>(op->getDialect())) {
+    return false;
+  }
+
   // TODO(#8637): `tensor.collapse_shape` and `tensor.expand_shape` are
   // trivially clonable too, but they cause problems
   // with bufferization. Make them clonable when fixed.

compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir

Lines changed: 41 additions & 0 deletions
@@ -604,3 +604,44 @@ util.func public @attention_clone_mask(%arg0: tensor<?x?xf16>,
 // CHECK-SAME: ins({{.+}}, %[[MASK]] :
 // CHECK: flow.return %[[ATTENTION]]
 // CHECK: return %[[DISPATCH]]
+
+// -----
+
+util.func public @dont_clone_flow_ops(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>, %arg3: tensor<?x?xi1>) -> tensor<?x?xf16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf16>
+  %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xf16>
+  %dim_1 = tensor.dim %arg2, %c1 : tensor<?x?xf16>
+  %dim_2 = tensor.dim %arg3, %c0 : tensor<?x?xi1>
+  %dim_3 = tensor.dim %arg3, %c1 : tensor<?x?xi1>
+  %false = arith.constant false
+  %true = arith.constant true
+  %cst = arith.constant 1.000000e+00 : f16
+  %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xi1>
+  %1 = tensor.empty(%dim, %dim_1) : tensor<?x?xf16>
+  %2 = flow.tensor.transfer %arg3 : tensor<?x?xi1>{%dim_2, %dim_3} to #hal.device.promise<@dev_a>
+  %3 = flow.dispatch.region -> (tensor<?x?xf16>{%dim, %dim_1}) {
+    %4 = iree_linalg_ext.attention {
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                         affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+                         affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                         affine_map<(d0, d1, d2, d3) -> ()>,
+                         affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+                         affine_map<(d0, d1, d2, d3) -> (d0, d1)>]}
+        ins(%arg0, %arg1, %arg2, %cst, %2 : tensor<?x?xf16>, tensor<?x?xf16>,
+            tensor<?x?xf16>, f16, tensor<?x?xi1>) outs(%1 : tensor<?x?xf16>) {
+      ^bb0(%in: f32):
+        iree_linalg_ext.yield %in : f32
+    } -> tensor<?x?xf16>
+    flow.return %4 : tensor<?x?xf16>
+  }
+  util.return %3 : tensor<?x?xf16>
+}
+// CHECK-LABEL: func public @dont_clone_flow_ops
+// CHECK: %[[MASK:.+]] = flow.tensor.transfer
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: ins({{.+}}, %[[MASK]] :
+// CHECK: flow.return %[[ATTENTION]]
+// CHECK: return %[[DISPATCH]]
