Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ namespace arith {
/// is set, values outside the range of the destination type are clamped
/// to the largest value of that type instead of being rewritten to Inf (aka
/// NaN).
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool convertFP8Arithmetic,
bool saturateFP8Truncf,
bool allowPackedF16Rtz,
amdgpu::Chipset chipset,
PatternBenefit benefit = 1);
/// If `supportsScaledExtTrunc` is set, the ScalingExtFRewritePattern and
/// ScalingTruncFRewritePattern rewrites are added to `patterns` as well; the
/// caller decides whether the target supports them instead of this function
/// deriving that from `chipset`.
void populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
amdgpu::Chipset chipset, PatternBenefit benefit = 1);
} // namespace arith
} // namespace mlir

Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,8 +690,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
PatternBenefit benefit) {
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
Chipset chipset, PatternBenefit benefit) {

if (convertFP8Arithmetic) {
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
Expand All @@ -702,7 +702,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
if (allowPackedF16Rtz)
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);

if (chipset >= kGfx950) {
if (supportsScaledExtTrunc) {
patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
}
Expand All @@ -720,9 +720,10 @@ void ArithToAMDGPUConversionPass::runOnOperation() {

bool convertFP8Arithmetic =
*maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
supportsScaledExtTrunc, *maybeChipset);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}