@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Constant used to make the rounding bias.
     Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
     // Constant used to generate a quiet NaN.
-    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
     // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
     Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };

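For readers following the bfloat16 path above, the rounding trick is easier to see outside the rewriter API. A minimal standalone sketch, assuming the input f32 has already been reinterpreted as an integer (bf16TruncRNE is an illustrative helper, not part of the patch):

#include <cstdint>

// Round-to-nearest-even truncation of an f32 bit pattern to bf16, mirroring
// the constants 0x7fff (bias), 0x7fc0 (quiet NaN), 16 and 1 used above.
uint16_t bf16TruncRNE(uint32_t f32Bits) {
  // NaN: exponent all ones and a nonzero mantissa.
  if ((f32Bits & 0x7fffffff) > 0x7f800000)
    return 0x7fc0;
  // Add 0x7fff plus the lowest bit that survives truncation, so that exact
  // ties round toward an even bf16 mantissa, then drop the low 16 bits.
  uint32_t bias = 0x7fff + ((f32Bits >> 16) & 1);
  return static_cast<uint16_t>((f32Bits + bias) >> 16);
}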
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not an ext of F8E8M0FNU");
+    }
+
+    if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not an ext of F8E8M0FNU to a larger 16-bit or 32-bit float");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // Constants for the NaN encodings of f8E8M0FNU and f32.
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    // Place the 8 exponent bits at the f32 exponent position (bit 23).
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // Map the f8E8M0FNU NaN encoding (0xff) to an f32 NaN.
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.isBF16() || resultETy.isF16()) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
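The extension above is pure bit manipulation, so it can be summarized as a scalar sketch (e8m0ToF32Bits is an illustrative name, not part of the patch):

#include <cstdint>

// f8E8M0FNU holds only an 8-bit biased exponent (nominal value 2^(e - 127)),
// with no sign or mantissa bits; 0xff is its single NaN encoding.
uint32_t e8m0ToF32Bits(uint8_t e8m0) {
  if (e8m0 == 0xff)
    return 0xffffffff; // all-ones f32 NaN, matching cF32NaN above
  // Shift the exponent into the f32 exponent field (bit 23); sign and
  // mantissa stay zero.
  return static_cast<uint32_t>(e8m0) << 23;
}

When the requested result type is bf16 or f16, the pattern extends to f32 first and then emits a final arith.truncf, leaving that narrowing to whatever handles truncf for those types.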
+/*
+TruncF to F8E8M0 is expected to extract the exponent bits out of the f32 type.
+Since all kinds of Infs and NaNs map to the same exponent bits in f32, they
+all map to NaN in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!resultETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+    if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+    if (!operandETy.isF32()) {
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    }
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
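The truncation direction is the mirror image; as a scalar sketch (f32BitsToE8M0 is an illustrative name, not part of the patch):

#include <cstdint>

// Keep only the f32 exponent field. After the shift by 23 the sign bit sits
// at bit 8, so truncating to 8 bits drops it, and every f32 Inf or NaN
// (exponent 0xff) collapses to the f8E8M0FNU NaN encoding 0xff.
uint8_t f32BitsToE8M0(uint32_t f32Bits) {
  return static_cast<uint8_t>(f32Bits >> 23);
}

Note that the mantissa bits are simply discarded, which is why the pattern bails out when an explicit rounding mode is requested.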
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
       arith::MinNumFOp
     >();

-    if (includeBf16) {
+    if (includeBf16) {
       arith::populateExpandBFloat16Patterns(patterns);
+    }
+    if (includeF8E8M0) {
+      arith::populateExpandF8E8M0Patterns(patterns);
+    }
+    if (includeBf16 || includeF8E8M0) {
       target.addDynamicallyLegalOp<arith::ExtFOp>(
-        [](arith::ExtFOp op) {
+        [=](arith::ExtFOp op) {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isBF16() && outETy.isF32());
+          if (includeBf16 && includeF8E8M0)
+            return !(inETy.isBF16() && outETy.isF32()) &&
+                   !(inETy.isF8E8M0FNU() &&
+                     (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
+          if (includeBf16)
+            return !(inETy.isBF16() && outETy.isF32());
+          return !(inETy.isF8E8M0FNU() &&
+                   (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
         });

       target.addDynamicallyLegalOp<arith::TruncFOp>(
-        [](arith::TruncFOp op) {
+        [=](arith::TruncFOp op) {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isF32() && outETy.isBF16());
+          if (includeBf16 && includeF8E8M0)
+            return !(inETy.isF32() && outETy.isBF16()) &&
+                   !(outETy.isF8E8M0FNU() &&
+                     (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
+          if (includeBf16)
+            return !(inETy.isF32() && outETy.isBF16());
+          return !(outETy.isF8E8M0FNU() &&
+                   (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
         });
     }
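The three-way branching in each lambda encodes a single rule: an extf or truncf stays legal unless it matches an expansion that is actually enabled. A possible equivalent form of the extf predicate, shown only as a sketch of that logic and not as part of the patch:

        [=](arith::ExtFOp op) {
          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
          Type outETy = getElementTypeOrSelf(op.getType());
          // Legal iff the op matches neither enabled expansion.
          bool bf16Case = includeBf16 && inETy.isBF16() && outETy.isF32();
          bool f8E8M0Case =
              includeF8E8M0 && inETy.isF8E8M0FNU() &&
              (outETy.isF32() || outETy.isBF16() || outETy.isF16());
          return !bf16Case && !f8E8M0Case;
        }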
-
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }

+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+      patterns.getContext());
+}
+
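Downstream users can mix the new helper with the existing one when assembling their own pattern sets. A minimal sketch, assuming the declaration sits next to populateExpandBFloat16Patterns in mlir/Dialect/Arith/Transforms/Passes.h (populateMyLoweringPatterns is a hypothetical hook):

#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical downstream hook: pull in the f8E8M0 expansion alongside the
// existing bf16 expansion before running a dialect conversion.
static void populateMyLoweringPatterns(mlir::RewritePatternSet &patterns) {
  mlir::arith::populateExpandBFloat16Patterns(patterns);
  mlir::arith::populateExpandF8E8M0Patterns(patterns);
}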
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off