Skip to content

Commit 56bb652

Browse files
Disable Attention V operand transposition. (iree-org#19810)
This impacts the ability to horizontally fuse the matmuls that feed into the `Q-K-V` transpose. The improvements previously seen with this change were likely due to a reduction in copy overheads, which are no longer an issue. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent dc4e900 commit 56bb652

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

compiler/src/iree/compiler/GlobalOptimization/Passes.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ static llvm::cl::opt<bool> clEnableTransposePropagation(
3434
llvm::cl::desc(
3535
"Enables propagation of transpose ops to improve fusion chances."),
3636
llvm::cl::init(true));
37+
static llvm::cl::opt<bool> clEnableAttentionVTranspose(
38+
"iree-global-opt-enable-attention-v-transpose",
39+
llvm::cl::desc("Enables transposition of v operand of attention ops,"),
40+
llvm::cl::init(true));
3741

3842
// TODO(hanchung): Remove the flag. We don't want to do early materialization by
3943
// default. Because it won't work for heterogeneous computing. This is not the
@@ -157,8 +161,11 @@ void buildGlobalOptimizationPassPipeline(
157161
.addPredicatedPass(
158162
clEnableTransposePropagation,
159163
[&]() {
160-
return createPropagateLinalgTransposePass(
161-
transformOptions.options.aggressiveTransposePropagation);
164+
PropagateLinalgTransposePassOptions options;
165+
options.enableAggressivePropagation =
166+
transformOptions.options.aggressiveTransposePropagation;
167+
options.enableAttentionVTranspose = clEnableAttentionVTranspose;
168+
return createPropagateLinalgTransposePass(options);
162169
})
163170
.addPass(IREE::Flow::createCanonicalizerPass)
164171
.addPass(mlir::createCSEPass);

compiler/src/iree/compiler/GlobalOptimization/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def PropagateLinalgTransposePass :
115115
"Flag used for lit-testing sinking patterns only. Not for general usage">,
116116
Option<"testBubblingOnly", "test-bubbling-only", "bool", /*default=*/"false",
117117
"Flag used for lit-testing bubbling patterns only. Not for general usage">,
118+
Option<"enableAttentionVTranspose", "enable-attention-v-transpose", "bool",
119+
/*default=*/"true", "Enable transposition of attention v operand">,
118120
];
119121
}
120122

compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,15 +1092,19 @@ void PropagateLinalgTransposePass::runOnOperation() {
10921092
linalg::populateFoldReshapeOpsByExpansionPatterns(bubblingPatterns,
10931093
reshapePropagationFn);
10941094
linalg::FillOp::getCanonicalizationPatterns(bubblingPatterns, context);
1095-
linalg::ControlFusionFn bubbleTransposeControlFn =
1096-
[](OpOperand *fusedOperand) {
1097-
Operation *producer = fusedOperand->get().getDefiningOp();
1098-
Operation *consumer = fusedOperand->getOwner();
10991095

1100-
return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
1101-
};
1102-
IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
1103-
bubblingPatterns, bubbleTransposeControlFn);
1096+
if (enableAttentionVTranspose) {
1097+
linalg::ControlFusionFn bubbleTransposeControlFn =
1098+
[](OpOperand *fusedOperand) {
1099+
Operation *producer = fusedOperand->get().getDefiningOp();
1100+
Operation *consumer = fusedOperand->getOwner();
1101+
1102+
return IREE::Flow::isNonNullAndOutsideDispatch(
1103+
{producer, consumer});
1104+
};
1105+
IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
1106+
bubblingPatterns, bubbleTransposeControlFn);
1107+
}
11041108
bubblingPatterns.insert<FuseTransposeWithProducerLinalgOp>(
11051109
context, enableAggressivePropagation);
11061110
bubblingPatterns.insert<BubbleTransposeThroughCollapseShape>(context);

0 commit comments

Comments
 (0)