@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
   Chipset chipset;
-  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+                             PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
 
   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
-                               Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
-        chipset(chipset) {}
+                               Chipset chipset, PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit),
+        saturateFP8(saturateFP8), chipset(chipset) {}
   Chipset chipset;
 
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
     : OpRewritePattern<arith::ScalingExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingExtFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
     : OpRewritePattern<arith::ScalingTruncFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingTruncFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -667,19 +662,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+    PatternBenefit benefit) {
 
   if (convertFP8Arithmetic) {
-    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
-    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
-                                               saturateFP8Truncf, chipset);
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+                                             benefit);
+    patterns.add<TruncFToFloat8RewritePattern>(
+        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
   }
   if (allowPackedF16Rtz)
-    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
 
   if (chipset >= kGfx950) {
-    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
-    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
   }
 }
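
A minimal caller sketch (hypothetical, not part of this diff), assuming the updated populateArithToAMDGPUConversionPatterns signature above; ctx and maybeChipset are placeholder names for values set up elsewhere:

  // Hypothetical usage: pass an explicit benefit so these rewrites are tried
  // before patterns registered with the default benefit of 1.
  RewritePatternSet patterns(ctx);
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
      /*allowPackedF16Rtz=*/true, /*chipset=*/maybeChipset,
      /*benefit=*/PatternBenefit(2));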