@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
3535 return rewriter.create <arith::ConstantOp>(loc, attr);
3636}
3737
38+ // / Creates shapedType using shape from cloneFrom and base type from cloneTo
39+ static Type cloneToShapedType (Type cloneFrom, Type cloneTo) {
40+ if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
41+ return shapedTy.clone (cloneTo);
42+ }
43+ return cloneTo;
44+ }
45+
3846namespace {
3947
4048// / Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
225233 return rewriter.notifyMatchFailure (op, " not a ext of bf16 to f32." );
226234 }
227235
228- Type i16Ty = b.getI16Type ();
229- Type i32Ty = b.getI32Type ();
230- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
231- i16Ty = shapedTy.clone (i16Ty);
232- i32Ty = shapedTy.clone (i32Ty);
233- }
236+ Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
237+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
234238
235239 Value bitcast = b.create <arith::BitcastOp>(i16Ty, operand);
236240 Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
264268 op, " only applicable to default rounding mode." );
265269 }
266270
267- Type i16Ty = b.getI16Type ();
268- Type i32Ty = b.getI32Type ();
269- Type f32Ty = b.getF32Type ();
270- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
271- i16Ty = shapedTy.clone (i16Ty);
272- i32Ty = shapedTy.clone (i32Ty);
273- f32Ty = shapedTy.clone (f32Ty);
274- }
271+ Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
272+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
275273
276274 // Algorithm borrowed from this excellent code:
277275 // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
291289 // Constant used to make the rounding bias.
292290 Value c7FFF = createConst (op.getLoc (), i32Ty, 0x7fff , rewriter);
293291 // Constant used to generate a quiet NaN.
294- Value c7FC0_i16 = createConst (op.getLoc (), i16Ty, 0x7fc0 , rewriter);
292+ Value c7FC0I16 = createConst (op.getLoc (), i16Ty, 0x7fc0 , rewriter);
295293 // Small constants used to address bits.
296294 Value c16 = createConst (op.getLoc (), i32Ty, 16 , rewriter);
297295 Value c1 = createConst (op.getLoc (), i32Ty, 1 , rewriter);
@@ -313,18 +311,104 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
313311 // Now that the rounding-bias has been added, truncating the low bits
314312 // yields the correctly rounded result.
315313 Value biasedAndShifted = b.create <arith::ShRUIOp>(biased, c16);
316- Value normalCaseResult_i16 =
314+ Value normalCaseResultI16 =
317315 b.create <arith::TruncIOp>(i16Ty, biasedAndShifted);
318316 // Select either the above-computed result, or a quiet NaN constant
319317 // if the input was NaN.
320318 Value select =
321- b.create <arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16 );
319+ b.create <arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16 );
322320 Value result = b.create <arith::BitcastOp>(resultTy, select);
323321 rewriter.replaceOp (op, result);
324322 return success ();
325323 }
326324};
327325
326+ struct F8E8M0ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
327+ using OpRewritePattern::OpRewritePattern;
328+ LogicalResult matchAndRewrite (arith::ExtFOp op,
329+ PatternRewriter &rewriter) const final {
330+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
331+ Value operand = op.getOperand ();
332+ Type operandTy = operand.getType ();
333+ Type resultTy = op.getType ();
334+ Type operandETy = getElementTypeOrSelf (operandTy);
335+ Type resultETy = getElementTypeOrSelf (resultTy);
336+
337+ if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
338+ return rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU" );
339+ }
340+
341+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
342+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
343+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
344+
345+ Value bitcast = b.create <arith::BitcastOp>(i8Ty, operand);
346+ // create constants for NaNs
347+ Value cF8NaN = createConst (op.getLoc (), i8Ty, 0xff , rewriter);
348+ Value cF32NaN = createConst (op.getLoc (), i32Ty, 0xffffffff , rewriter);
349+ Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
350+
351+ Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
352+ Value f32Bits = b.create <arith::ShLIOp>(exti, cF32MantissaWidth);
353+
354+ Value isNan =
355+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
356+ // select for NaNs
357+ f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
358+ Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
359+ if (resultETy.getIntOrFloatBitWidth () < 32 ) {
360+ result = b.create <arith::TruncFOp>(resultTy, result);
361+ } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
362+ result = b.create <arith::ExtFOp>(resultTy, result);
363+ }
364+ rewriter.replaceOp (op, result);
365+ return success ();
366+ }
367+ };
368+
369+ /*
370+ TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
371+ Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
372+ they all map to NaN in F8E8M0 Type.
373+ */
374+ struct F8E8M0TruncFOpConverter : public OpRewritePattern <arith::TruncFOp> {
375+ using OpRewritePattern::OpRewritePattern;
376+ LogicalResult matchAndRewrite (arith::TruncFOp op,
377+ PatternRewriter &rewriter) const final {
378+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
379+ Value operand = op.getOperand ();
380+ Type operandTy = operand.getType ();
381+ Type operandETy = getElementTypeOrSelf (operandTy);
382+ Type resultTy = op.getType ();
383+ Type resultETy = getElementTypeOrSelf (resultTy);
384+ if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
385+ return rewriter.notifyMatchFailure (op, " not a truncf to f8E8M0FNU" );
386+ }
387+
388+ if (op.getRoundingmodeAttr ()) {
389+ return rewriter.notifyMatchFailure (
390+ op, " only applicable to default rounding mode." );
391+ }
392+
393+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
394+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
395+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
396+
397+ if (operandETy.getIntOrFloatBitWidth () < 32 ) {
398+ operand = b.create <arith::ExtFOp>(f32Ty, operand);
399+ } else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
400+ operand = b.create <arith::TruncFOp>(f32Ty, operand);
401+ }
402+ Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
403+ Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
404+ Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
405+ Value exp8Bits = b.create <arith::TruncIOp>(i8Ty, f32SignExp);
406+ Value result = b.create <arith::BitcastOp>(resultTy, exp8Bits);
407+ rewriter.replaceOp (op, result);
408+ return success ();
409+ }
410+ };
411+
328412struct ArithExpandOpsPass
329413 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
330414 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -353,20 +437,34 @@ struct ArithExpandOpsPass
353437
354438 if (includeBf16) {
355439 arith::populateExpandBFloat16Patterns (patterns);
356- target.addDynamicallyLegalOp <arith::ExtFOp>(
357- [](arith::ExtFOp op) {
358- Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
359- Type outETy = getElementTypeOrSelf (op.getType ());
360- return !(inETy.isBF16 () && outETy.isF32 ());
361- });
362-
363- target.addDynamicallyLegalOp <arith::TruncFOp>(
364- [](arith::TruncFOp op) {
365- Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
366- Type outETy = getElementTypeOrSelf (op.getType ());
367- return !(inETy.isF32 () && outETy.isBF16 ());
368- });
369440 }
441+ if (includeF8E8M0) {
442+ arith::populateExpandF8E8M0Patterns (patterns);
443+ }
444+
445+ target.addDynamicallyLegalOp <arith::ExtFOp>(
446+ [=](arith::ExtFOp op) {
447+ Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
448+ Type outETy = getElementTypeOrSelf (op.getType ());
449+ bool legalTypes = true ;
450+ if (includeBf16)
451+ legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
452+ if (includeF8E8M0)
453+ legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
454+ return legalTypes;
455+ });
456+
457+ target.addDynamicallyLegalOp <arith::TruncFOp>(
458+ [=](arith::TruncFOp op) {
459+ Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
460+ Type outETy = getElementTypeOrSelf (op.getType ());
461+ bool legalTypes = true ;
462+ if (includeBf16)
463+ legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
464+ if (includeF8E8M0)
465+ legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
466+ return legalTypes;
467+ });
370468
371469 // clang-format on
372470 if (failed (applyPartialConversion (getOperation (), target,
@@ -389,6 +487,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
389487 patterns.getContext ());
390488}
391489
490+ void mlir::arith::populateExpandF8E8M0Patterns (RewritePatternSet &patterns) {
491+ patterns.add <F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
492+ patterns.getContext ());
493+ }
494+
392495void mlir::arith::populateArithExpandOpsPatterns (RewritePatternSet &patterns) {
393496 populateCeilFloorDivExpandOpsPatterns (patterns);
394497 // clang-format off
0 commit comments