Skip to content

Commit 317ea4e

Browse files
krzysz00mahesh-attarde
authored andcommitted
[mlir][ArithToAMDGPU][NFC] Add PatternBenefit (llvm#150091)
Since there may be caseses where these patterns are run alongside the generic patterns from ArithExpandOps, add a PatternBenefit argument to allow these architecture-specific patterns to be prioritized.
1 parent 5c211d3 commit 317ea4e

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
1111

1212
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
13+
#include "mlir/IR/PatternMatch.h"
1314
#include <memory>
1415
#include <string>
1516

@@ -31,7 +32,8 @@ void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
3132
bool convertFP8Arithmetic,
3233
bool saturateFP8Truncf,
3334
bool allowPackedF16Rtz,
34-
amdgpu::Chipset chipset);
35+
amdgpu::Chipset chipset,
36+
PatternBenefit benefit = 1);
3537
} // namespace arith
3638
} // namespace mlir
3739

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4949
using OpRewritePattern::OpRewritePattern;
5050

5151
Chipset chipset;
52-
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
53-
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
52+
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
53+
PatternBenefit benefit)
54+
: OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
5455

5556
LogicalResult matchAndRewrite(arith::ExtFOp op,
5657
PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
5960
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
6061
bool saturateFP8 = false;
6162
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
62-
Chipset chipset)
63-
: OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
64-
chipset(chipset) {}
63+
Chipset chipset, PatternBenefit benefit)
64+
: OpRewritePattern::OpRewritePattern(ctx, benefit),
65+
saturateFP8(saturateFP8), chipset(chipset) {}
6566
Chipset chipset;
6667

6768
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
8182
: OpRewritePattern<arith::ScalingExtFOp> {
8283
using OpRewritePattern::OpRewritePattern;
8384

84-
ScalingExtFRewritePattern(MLIRContext *ctx)
85-
: OpRewritePattern::OpRewritePattern(ctx) {}
86-
8785
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
8886
PatternRewriter &rewriter) const override;
8987
};
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
9290
: OpRewritePattern<arith::ScalingTruncFOp> {
9391
using OpRewritePattern::OpRewritePattern;
9492

95-
ScalingTruncFRewritePattern(MLIRContext *ctx)
96-
: OpRewritePattern::OpRewritePattern(ctx) {}
97-
9893
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
9994
PatternRewriter &rewriter) const override;
10095
};
@@ -667,19 +662,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
667662

668663
void mlir::arith::populateArithToAMDGPUConversionPatterns(
669664
RewritePatternSet &patterns, bool convertFP8Arithmetic,
670-
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
665+
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
666+
PatternBenefit benefit) {
671667

672668
if (convertFP8Arithmetic) {
673-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
674-
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
675-
saturateFP8Truncf, chipset);
669+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
670+
benefit);
671+
patterns.add<TruncFToFloat8RewritePattern>(
672+
patterns.getContext(), saturateFP8Truncf, chipset, benefit);
676673
}
677674
if (allowPackedF16Rtz)
678-
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
675+
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
679676

680677
if (chipset >= kGfx950) {
681-
patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
682-
patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
678+
patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
679+
patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
683680
}
684681
}
685682

0 commit comments

Comments
 (0)