@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
   Chipset chipset;
-  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+                             PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
 
   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
-                               Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
-        chipset(chipset) {}
+                               Chipset chipset, PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit),
+        saturateFP8(saturateFP8), chipset(chipset) {}
   Chipset chipset;
 
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
     : OpRewritePattern<arith::ScalingExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingExtFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
     : OpRewritePattern<arith::ScalingTruncFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingTruncFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -667,19 +662,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+    PatternBenefit benefit) {
 
   if (convertFP8Arithmetic) {
-    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
-    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
-                                               saturateFP8Truncf, chipset);
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+                                             benefit);
+    patterns.add<TruncFToFloat8RewritePattern>(
+        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
   }
   if (allowPackedF16Rtz)
-    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
 
   if (chipset >= kGfx950) {
-    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
-    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
   }
 }
 
0 commit comments