Skip to content

Commit 11fe5cd

Browse files
authored
[Codegen] Add canonicalization pass to track lowering configs (iree-org#19138)
This allows us to retain lowering configs (or other discardable attributes we need) through canonicalization patterns. This patch only replaces canonicalizer uses before bufferization/vectorization as currently those are the only places where we rely on lowering configs.
1 parent 1c43bcd commit 11fe5cd

File tree

10 files changed

+167
-35
lines changed

10 files changed

+167
-35
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ iree_compiler_cc_library(
9393
"BufferizeCopyOnlyDispatchesPass.cpp",
9494
"CleanupBufferAllocViewPass.cpp",
9595
"ConcretizePadResultShape.cpp",
96+
"ConfigTrackingCanonicalizer.cpp",
9697
"ConvertBf16ArithToF32.cpp",
9798
"ConvertBf16ToUInt16Buffers.cpp",
9899
"ConvertToDestinationPassingStylePass.cpp",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ iree_cc_library(
8585
"BufferizeCopyOnlyDispatchesPass.cpp"
8686
"CleanupBufferAllocViewPass.cpp"
8787
"ConcretizePadResultShape.cpp"
88+
"ConfigTrackingCanonicalizer.cpp"
8889
"ConvertBf16ArithToF32.cpp"
8990
"ConvertBf16ToUInt16Buffers.cpp"
9091
"ConvertToDestinationPassingStylePass.cpp"

compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Common/Transforms.h"
89
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
910
#include "llvm/Support/Debug.h"
1011
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -138,7 +139,11 @@ class ConcretizePadResultShapePass final
138139
{
139140
RewritePatternSet patterns(context);
140141
populateConcretizePadResultShapePatterns(patterns);
141-
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
142+
GreedyRewriteConfig config;
143+
auto listener = ConfigTrackingListener();
144+
config.listener = &listener;
145+
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns),
146+
config))) {
142147
return signalPassFailure();
143148
}
144149
}

compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Common/Transforms.h"
9+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12+
13+
#define DEBUG_TYPE "iree-codegen-config-tracking-canonicalizer"
14+
15+
namespace mlir::iree_compiler {
16+
17+
#define GEN_PASS_DEF_CONFIGTRACKINGCANONICALIZERPASS
18+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
19+
20+
/// Returns the operation that defines `v`, looking through any chain of
/// `tensor.cast` ops. Returns nullptr when `v`, or the source at the end of
/// the cast chain, has no defining operation (e.g. it is a block argument).
static Operation *skipCastsDefiningOp(Value v) {
  Operation *producer = v.getDefiningOp();
  // `getDefiningOp` returns null for values without a defining op, and plain
  // `dyn_cast` asserts on a null pointer, so use the null-tolerant variant.
  while (auto castProducer = dyn_cast_or_null<tensor::CastOp>(producer)) {
    producer = castProducer.getSource().getDefiningOp();
  }
  return producer;
}
27+
28+
/// Listener hook called when `op` is replaced by the values in `replacement`.
/// If `op` carries a lowering config and every replacement value is produced
/// (possibly through `tensor.cast` ops) by a single operation with the same
/// name as `op`, the config is copied onto that operation unless it already
/// has one.
void ConfigTrackingListener::notifyOperationReplaced(Operation *op,
                                                     ValueRange replacement) {
  // Without replacement values there is no producer to carry the config.
  if (replacement.empty()) {
    return;
  }

  // Nothing to track when the replaced op has no lowering config.
  IREE::Codegen::LoweringConfigAttrInterface trackedConfig =
      getLoweringConfig(op);
  if (!trackedConfig) {
    return;
  }

  // Only an operation of the same kind as the original may inherit the config.
  Operation *newOp = skipCastsDefiningOp(replacement.front());
  if (newOp == nullptr || newOp->getName() != op->getName()) {
    return;
  }

  // Be conservative: every remaining replacement value must come from that
  // same operation.
  const bool hasOtherProducer =
      llvm::any_of(replacement.drop_front(), [newOp](Value value) {
        return skipCastsDefiningOp(value) != newOp;
      });
  if (hasOtherProducer) {
    return;
  }

  // Keep whatever config the new operation already has.
  if (!getLoweringConfig(newOp)) {
    setLoweringConfig(newOp, trackedConfig);
  }
}
62+
63+
namespace {
64+
65+
/// Add the corresponding fast-math flags to operations given a floating-point
66+
/// optimization mode.
67+
// TODO: For now we only allow default flags, such as arithmetic reassociation.
68+
struct ConfigTrackingCanonicalizerPass final
69+
: impl::ConfigTrackingCanonicalizerPassBase<
70+
ConfigTrackingCanonicalizerPass> {
71+
public:
72+
using impl::ConfigTrackingCanonicalizerPassBase<
73+
ConfigTrackingCanonicalizerPass>::ConfigTrackingCanonicalizerPassBase;
74+
/// Initialize the canonicalizer by building the set of patterns used during
75+
/// execution.
76+
LogicalResult initialize(MLIRContext *context) override {
77+
// Inherit the same config defaults from the upstream canonicalizer pass.
78+
config.useTopDownTraversal = true;
79+
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Normal;
80+
81+
RewritePatternSet owningPatterns(context);
82+
for (auto *dialect : context->getLoadedDialects())
83+
dialect->getCanonicalizationPatterns(owningPatterns);
84+
for (RegisteredOperationName op : context->getRegisteredOperations())
85+
op.getCanonicalizationPatterns(owningPatterns, context);
86+
87+
patterns =
88+
std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
89+
return success();
90+
}
91+
92+
void runOnOperation() override {
93+
// Canonicalization is best-effort. Non-convergence is not a pass failure.
94+
auto listener = ConfigTrackingListener();
95+
config.listener = &listener;
96+
LogicalResult didConverge =
97+
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
98+
if (this->testConvergence && failed(didConverge)) {
99+
getOperation()->emitError("Canonicalizer failed to converge");
100+
return signalPassFailure();
101+
}
102+
}
103+
GreedyRewriteConfig config;
104+
std::shared_ptr<const FrozenRewritePatternSet> patterns;
105+
};
106+
107+
} // namespace
108+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ def BufferizeCopyOnlyDispatchesPass :
4444
}];
4545
}
4646

47+
// Declares the codegen canonicalization pass whose implementation
// (ConfigTrackingCanonicalizer.cpp) runs canonicalization patterns with a
// ConfigTrackingListener attached so lowering configs are retained.
def ConfigTrackingCanonicalizerPass :
48+
Pass<"iree-codegen-config-tracking-canonicalize", ""> {
49+
let summary = "Codegen specific canonicalization pass that tracks lowering configs";
50+
let options = [
51+
// When set, the pass signals failure if the greedy pattern application does
// not converge; by default non-convergence is ignored.
Option<"testConvergence", "test-convergence", "bool",
52+
/*default=*/"false", "Fails if the patterns fail to converge">
53+
];
54+
}
55+
4756
def CleanupBufferAllocViewPass :
4857
InterfacePass<"iree-codegen-cleanup-buffer-alloc-view", "mlir::FunctionOpInterface"> {
4958
let summary =

compiler/src/iree/compiler/Codegen/Common/Transforms.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ struct OneShotBufferizationOptions;
1818

1919
namespace mlir::iree_compiler {
2020

21+
/// Common helper class for tracking lowering configs through pattern
22+
/// applications. Install an instance as the listener of a rewrite driver
/// (e.g. through `GreedyRewriteConfig::listener`) to have lowering configs
/// carried over from replaced operations to their replacements.
23+
class ConfigTrackingListener : public RewriterBase::Listener {
24+
public:
25+
ConfigTrackingListener() = default;
26+
/// Called when `op` is replaced by `replacement`. If `op` has a lowering
/// config and all replacement values are produced (looking through
/// `tensor.cast`) by one operation with the same name as `op`, the config is
/// set on that operation unless it already has one.
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
27+
};
28+
2129
using IGEMMConfigFn =
2230
std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
2331
using IGEMMControlFn = std::function<bool(Operation *)>;

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
121121
funcPassManager.addPass(createConvertToDestinationPassingStylePass());
122122
funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
123123
}
124-
funcPassManager.addPass(createCanonicalizerPass());
124+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
125125
funcPassManager.addPass(createCSEPass());
126126
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
127127
funcPassManager.addPass(createConcretizePadResultShapePass());
@@ -425,7 +425,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
425425
funcPassManager.addPass(createTensorToVectorVectorizePadPass());
426426
if (pipelineOpt.decomposePackUnPackOps) {
427427
funcPassManager.addPass(createDecomposePackUnPackOpsPass());
428-
funcPassManager.addPass(createCanonicalizerPass());
428+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
429429
funcPassManager.addPass(createCSEPass());
430430
}
431431

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ static void tileAndDistributeToWorkgroup(
197197
// TODO(#16421): Disable decomposition due to failure in bufferization.
198198
// funcPassManager.addPass(
199199
// IREE::LinalgExt::createTileAndDecomposeAttentionPass());
200-
funcPassManager.addPass(createCanonicalizerPass());
200+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
201201
funcPassManager.addPass(createCSEPass());
202202
}
203203

@@ -238,13 +238,13 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager,
238238
void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
239239
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
240240

241-
funcPassManager.addPass(createCanonicalizerPass());
242-
funcPassManager.addPass(createCanonicalizerPass());
241+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
242+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
243243
funcPassManager.addPass(createCSEPass());
244244

245245
// Distribute linalg onto threads within the workgroup.
246246
funcPassManager.addPass(createGPUTensorTilePass());
247-
funcPassManager.addPass(createCanonicalizerPass());
247+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
248248
funcPassManager.addPass(createCSEPass());
249249

250250
// Linalg -> vector
@@ -365,7 +365,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
365365
GPUApplyTilingLevelPassOptions options;
366366
options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
367367
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
368-
funcPassManager.addPass(createCanonicalizerPass());
368+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
369369
funcPassManager.addPass(createCSEPass());
370370
}
371371

@@ -384,15 +384,15 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
384384
}
385385

386386
funcPassManager.addPass(createPropagateReshapesByExpansionPass());
387-
funcPassManager.addPass(createCanonicalizerPass());
387+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
388388
funcPassManager.addPass(createCSEPass());
389389

390390
// Step 4. Tile and fuse tileable ops to subgroups/threads.
391391
{
392392
GPUApplyTilingLevelPassOptions options;
393393
options.tilingLevel = IREE::GPU::TilingLevel::Thread;
394394
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
395-
funcPassManager.addPass(createCanonicalizerPass());
395+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
396396
funcPassManager.addPass(createCSEPass());
397397
}
398398
{
@@ -406,7 +406,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
406406
funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass(
407407
NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false,
408408
/*normalizeForall=*/true}));
409-
funcPassManager.addPass(createCanonicalizerPass());
409+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
410410
funcPassManager.addPass(createCSEPass());
411411

412412
// TODO: This LICM instance is load bearing due to brittleness of the
@@ -489,13 +489,13 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
489489
void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
490490
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
491491

492-
funcPassManager.addPass(createCanonicalizerPass());
493-
funcPassManager.addPass(createCanonicalizerPass());
492+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
493+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
494494
funcPassManager.addPass(createCSEPass());
495495

496496
// Distribute linalg onto threads within the workgroup.
497497
funcPassManager.addPass(createGPUTilePass());
498-
funcPassManager.addPass(createCanonicalizerPass());
498+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
499499
funcPassManager.addPass(createCSEPass());
500500
funcPassManager.addPass(
501501
IREE::LinalgExt::createDecomposeWinogradTransformPass());
@@ -512,7 +512,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
512512
// Post bufferization optimizations.
513513
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
514514
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
515-
funcPassManager.addPass(createCanonicalizerPass());
515+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
516516
funcPassManager.addPass(createCSEPass());
517517
funcPassManager.addPass(createOptimizeVectorTransferPass());
518518
funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
@@ -526,8 +526,8 @@ void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
526526
const GPUPipelineOptions &options) {
527527
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
528528

529-
funcPassManager.addPass(createCanonicalizerPass());
530-
funcPassManager.addPass(createCanonicalizerPass());
529+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
530+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
531531
funcPassManager.addPass(createCSEPass());
532532

533533
funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass());
@@ -727,8 +727,8 @@ void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
727727
const GPUPipelineOptions &options) {
728728
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
729729

730-
funcPassManager.addPass(createCanonicalizerPass());
731-
funcPassManager.addPass(createCanonicalizerPass());
730+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
731+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
732732
funcPassManager.addPass(createCSEPass());
733733

734734
funcPassManager.addPass(
@@ -844,7 +844,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
844844
funcPassManager.addPass(
845845
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
846846

847-
funcPassManager.addPass(createCanonicalizerPass());
847+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
848848
funcPassManager.addPass(createCSEPass());
849849
funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());
850850

@@ -855,12 +855,12 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
855855
options.allowZeroSlices = true;
856856
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
857857
funcPassManager.addPass(affine::createLoopCoalescingPass());
858-
funcPassManager.addPass(createCanonicalizerPass());
858+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
859859
funcPassManager.addPass(createCSEPass());
860860
}
861861

862862
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
863-
funcPassManager.addPass(createCanonicalizerPass());
863+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
864864
funcPassManager.addPass(createCSEPass());
865865

866866
// Set anchors at tensor level for vector distribution later and hoist out
@@ -927,9 +927,9 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
927927
void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
928928
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
929929
funcPassManager.addPass(createRematerializeParallelOpsPass());
930-
funcPassManager.addPass(createCanonicalizerPass());
930+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
931931
funcPassManager.addPass(createGPUTileReductionPass());
932-
funcPassManager.addPass(createCanonicalizerPass());
932+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
933933
funcPassManager.addPass(createCSEPass());
934934

935935
// Linalg -> vector
@@ -970,11 +970,11 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
970970

971971
void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
972972
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
973-
funcPassManager.addPass(createCanonicalizerPass());
973+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
974974
funcPassManager.addPass(createCSEPass());
975975

976976
funcPassManager.addPass(createGPUTensorTilePass());
977-
funcPassManager.addPass(createCanonicalizerPass());
977+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
978978
funcPassManager.addPass(createCSEPass());
979979

980980
funcPassManager.addPass(createDecomposePackUnPackOpsPass(
@@ -1165,7 +1165,7 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
11651165
addCommonTargetExecutablePreprocessingPasses(funcPassManager);
11661166
addEncodingToNopPasses(funcPassManager);
11671167
funcPassManager.addPass(createBlockDynamicDimensionsPass);
1168-
funcPassManager.addPass(createCanonicalizerPass);
1168+
funcPassManager.addPass(createConfigTrackingCanonicalizerPass);
11691169
funcPassManager.addPass(createCSEPass);
11701170
}
11711171
modulePassManager.addPass(createMaterializeUserConfigsPass());

0 commit comments

Comments
 (0)