 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -34,17 +35,17 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };
 
-struct ExtfOnFloat8RewritePattern final
-    : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult match(arith::ExtFOp op) const override;
   void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
 };
 
-struct TruncfToFloat8RewritePattern final
-    : public OpRewritePattern<arith::TruncFOp> {
-  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
+  bool saturateFP8 = false;
+  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
 
   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -62,7 +63,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
-LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   Type inType = op.getIn().getType();
   if (auto inVecType = inType.dyn_cast<VectorType>()) {
     if (inVecType.isScalable())
@@ -75,7 +76,7 @@ LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
 }
 
-void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
+void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
                                          PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
@@ -93,11 +94,13 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Value result =
       rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
   if (inType.getShape().empty()) {
-    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarIn =
+        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarExt =
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
+    result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+                                               ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
   for (int64_t i = 0; i < numElements; i += 4) {
@@ -108,9 +111,7 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
       Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
           loc, rewriter.getF32Type(), inSlice, j);
       Value asType = castF32To(outElemType, asFloat, loc, rewriter);
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, asType, result,
-          rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+      result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
     }
   }
   rewriter.replaceOp(op, result);
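
Both ExtF hunks above replace the vector.extractelement/vector.insertelement builders, which take the index as a dynamic SSA Value, with vector.extract/vector.insert builders that accept static positions, so the arith.constant index ops disappear. A minimal sketch of the two builder styles for a 1-D vector (hypothetical helper, not part of the patch; the names are assumed):

static Value insertAtStaticPos(PatternRewriter &rewriter, Location loc,
                               Value elem, Value vec, int64_t pos) {
  // Old style: materialize the position as an index constant and pass it as
  // an operand to vector.insertelement.
  //   Value idx = rewriter.createOrFold<arith::ConstantIndexOp>(loc, pos);
  //   return rewriter.create<vector::InsertElementOp>(loc, elem, vec, idx);
  // New style: the position is carried statically by vector.insert.
  return rewriter.create<vector::InsertOp>(loc, elem, vec, pos);
}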
@@ -127,7 +128,53 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
-LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+                        Type outElemType, Value source) {
+  Type sourceType = source.getType();
+  const llvm::fltSemantics &sourceSem =
+      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+  const llvm::fltSemantics &targetSem =
+      cast<FloatType>(outElemType).getFloatSemantics();
+
+  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+  bool ignoredLosesInfo = false;
+  // We can ignore conversion failures here because this conversion promotes
+  // from a smaller type to a larger one - ex. there can be no loss of precision
+  // when casting fp8 to f16.
+  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
+  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
+
+  Value inf = createScalarOrSplatConstant(
+      rewriter, loc, sourceType,
+      APFloat::getInf(sourceSem, /*Negative=*/false));
+  Value negInf = createScalarOrSplatConstant(
+      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
+  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, inf);
+  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, negInf);
+  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::UNO, source, source);
+  Value isNonFinite = rewriter.create<arith::OrIOp>(
+      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+  Value res =
+      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+  return res;
+}
+
+LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
     if (outVecType.isScalable())
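
The clampInput helper added in the hunk above saturates finite inputs to the target type's representable range while letting ±inf and NaN pass through unchanged, so the packed-truncation instructions never see finite values that would overflow to NaN. A minimal scalar model of that behavior, not the MLIR implementation itself (f8Max is an assumed stand-in for APFloat::getLargest of the target fp8 semantics):

#include <cmath>

static float clampForFp8(float x, float f8Max) {
  if (std::isinf(x) || std::isnan(x))
    return x; // non-finite inputs pass through so they remain inf/NaN
  return std::fmin(std::fmax(x, -f8Max), f8Max); // clamp finite values
}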
@@ -137,22 +184,27 @@ LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
       return failure();
     outType = outVecType.getElementType();
   }
+  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
+  if (inType && inType.getWidth() <= 8 && saturateFP8)
+    // Conversion between 8-bit floats is not supported with truncation enabled.
+    return failure();
   return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
 }
 
-void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
+void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  if (saturateFP8)
+    in = clampInput(rewriter, loc, outElemType, in);
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!in.getType().isa<VectorType>()) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = op.getOut().getType().cast<VectorType>();
@@ -161,26 +213,25 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
   if (outType.getShape().empty()) {
-    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarIn =
+        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
+    result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+                                               ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
 
   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value thisResult = nullptr;
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
-      Value elemA = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
+      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
       Value asFloatA = castToF32(elemA, loc, rewriter);
       Value asFloatB = nullptr;
       if (j + 1 < elemsThisOp) {
-        Value elemB = rewriter.create<vector::ExtractElementOp>(
-            loc, in,
-            rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
+        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
         asFloatB = castToF32(elemB, loc, rewriter);
       }
       thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -196,15 +247,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
 }
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
-      patterns.getContext());
+    RewritePatternSet &patterns, bool saturateFP8TruncF) {
+  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+                                             saturateFP8TruncF);
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
   Operation *op = getOperation();
   RewritePatternSet patterns(op->getContext());
-  arith::populateArithToAMDGPUConversionPatterns(patterns);
+  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
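
Since populateArithToAMDGPUConversionPatterns now takes a boolean, callers that build their own pattern set must pass the saturation flag explicitly. A caller sketch under that assumption (the helper name and driver boilerplate are illustrative, not part of this change):

// Illustrative only: wiring the new saturateFP8TruncF flag through a custom
// pattern-application driver.
static LogicalResult applyArithToAMDGPU(Operation *op, bool saturateFP8TruncF) {
  RewritePatternSet patterns(op->getContext());
  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8TruncF);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}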