Skip to content

Commit e122dcc

Browse files
authored
[Global Opt] Prevent expanding reduction dims (iree-org#20290)
Prevents fusing reshapes with reduction ops. This fixes a llama fp8 performance regression caused by interleaved parallel/reduction dimensions introduced after the LLVM integrate (iree-org/llvm-project@813bbe0). Adding `memref::populateResolveRankedShapedTypeResultDimsPatterns` is unrelated to that fix, but some dynamic dimensions were observed that were not getting simplified during this pass. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 5802af8 commit e122dcc

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2222
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
2323
#include "mlir/Dialect/Linalg/Utils/Utils.h"
24+
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
2425
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2526
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
2627
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -1001,6 +1002,7 @@ populateCommonCanonicalizationPatterns(MLIRContext *context,
10011002
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
10021003
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
10031004
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1005+
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
10041006
tensor::populateFoldTensorEmptyPatterns(patterns,
10051007
/*foldSingleUseOnly=*/false);
10061008
}
@@ -1140,7 +1142,7 @@ void PropagateLinalgTransposePass::runOnOperation() {
11401142
return false;
11411143
}
11421144
auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer);
1143-
if (!consumerLinalgOp) {
1145+
if (!consumerLinalgOp || consumerLinalgOp.getNumReductionLoops()) {
11441146
return false;
11451147
}
11461148
// Only reshape generic ops.

compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,3 +733,22 @@ util.func public @bubble_transpose_v_from_attention(%q: tensor<2x10x4096x64xf16>
733733
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[TRANS_V]], %[[ARG5]] : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16)
734734
// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x10x4096x64xf16>)
735735
// CHECK: util.return %[[ATTN]] : tensor<2x10x4096x64xf16>
736+
737+
// -----
738+
739+
util.func public @dont_reshape_reduction(%arg0: tensor<16x4x4xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> {
740+
%empty1 = tensor.empty(): tensor<16x4x4xf32>
741+
%0 = linalg.transpose ins(%arg0 : tensor<16x4x4xf32>)
742+
outs(%empty1 : tensor<16x4x4xf32>) permutation = [0, 2, 1]
743+
%collapse = tensor.collapse_shape %0 [[0], [1, 2]] : tensor<16x4x4xf32> into tensor<16x16xf32>
744+
%empty2 = tensor.empty(): tensor<16x16xf32>
745+
%1 = linalg.matmul ins(%collapse, %arg1: tensor<16x16xf32>, tensor<16x16xf32>)
746+
outs(%empty2 : tensor<16x16xf32>) -> tensor<16x16xf32>
747+
748+
util.return %1 : tensor<16x16xf32>
749+
}
750+
// APROP-LABEL: util.func public @dont_reshape_reduction
751+
// APROP: %[[V0:.+]] = linalg.transpose
752+
// APROP: %[[V1:.+]] = tensor.collapse_shape %[[V0]]
753+
// APROP: %[[V2:.+]] = linalg.matmul ins(%[[V1]]
754+
// APROP: util.return %[[V2]]

0 commit comments

Comments (0)