diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index f4a9518839224..fd144edf77452 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -28,12 +28,10 @@ namespace arith {
 /// is set, values outside the range of the destination type are clamped
 /// to the largest value of that type instead of being rewritten to Inf (aka
 /// NaN).
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
-                                             bool convertFP8Arithmetic,
-                                             bool saturateFP8Truncf,
-                                             bool allowPackedF16Rtz,
-                                             amdgpu::Chipset chipset,
-                                             PatternBenefit benefit = 1);
+void populateArithToAMDGPUConversionPatterns(
+    RewritePatternSet &patterns, bool convertFP8Arithmetic,
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
+    amdgpu::Chipset chipset, PatternBenefit benefit = 1);
 
 } // namespace arith
 } // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8230591123661..3d6f6cab42244 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -690,8 +690,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
-    PatternBenefit benefit) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
+    Chipset chipset, PatternBenefit benefit) {
   if (convertFP8Arithmetic) {
     patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
                                              benefit);
@@ -702,7 +702,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
 
-  if (chipset >= kGfx950) {
+  if (supportsScaledExtTrunc) {
     patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
     patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
   }
@@ -720,9 +720,10 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
 
   bool convertFP8Arithmetic =
       *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
+  bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
   arith::populateArithToAMDGPUConversionPatterns(
       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
-      *maybeChipset);
+      supportsScaledExtTrunc, *maybeChipset);
   if (failed(applyPatternsGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index 1d36be1108d26..a2b0aef594e61 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100
 
 // CHECK-LABEL: @conversion_f8_f32_fallback
 // CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
@@ -241,6 +242,9 @@ func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vec
 
 // -----
 
+// CHECK-GFX1100-LABEL: @conversion_scalar
+// CHECK-GFX1100: arith.scaling_extf
+
 // CHECK-LABEL: @conversion_scalar
 // CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
 // CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f8E5M2 to vector<1xf8E5M2>
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 90a86084ac93f..bc2c6a5aa0275 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100
 
 // CHECK-LABEL: @conversion_f8_fallback
 // CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf8E5M2>
@@ -163,6 +164,9 @@ func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector
 
 // -----
 
+// CHECK-GFX1100-LABEL: @conversion_scalar
+// CHECK-GFX1100: arith.scaling_truncf
+
 // CHECK-LABEL: @conversion_scalar
 // CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
 // CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f32 to vector<1xf32>