@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
3535  return  rewriter.create <arith::ConstantOp>(loc, attr);
3636}
3737
38+ // / Creates shapedType using shape from cloneFrom and base type from cloneTo
39+ static  Type cloneToShapedType (Type cloneFrom, Type cloneTo) {
40+   if  (auto  shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
41+     return  shapedTy.clone (cloneTo);
42+   }
43+   return  cloneTo;
44+ }
45+ 
3846namespace  {
3947
4048// / Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
225233      return  rewriter.notifyMatchFailure (op, " not a ext of bf16 to f32."  );
226234    }
227235
228-     Type i16Ty = b.getI16Type ();
229-     Type i32Ty = b.getI32Type ();
230-     if  (auto  shapedTy = dyn_cast<ShapedType>(operandTy)) {
231-       i16Ty = shapedTy.clone (i16Ty);
232-       i32Ty = shapedTy.clone (i32Ty);
233-     }
236+     Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
237+     Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
234238
235239    Value bitcast = b.create <arith::BitcastOp>(i16Ty, operand);
236240    Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
264268          op, " only applicable to default rounding mode."  );
265269    }
266270
267-     Type i16Ty = b.getI16Type ();
268-     Type i32Ty = b.getI32Type ();
269-     Type f32Ty = b.getF32Type ();
270-     if  (auto  shapedTy = dyn_cast<ShapedType>(operandTy)) {
271-       i16Ty = shapedTy.clone (i16Ty);
272-       i32Ty = shapedTy.clone (i32Ty);
273-       f32Ty = shapedTy.clone (f32Ty);
274-     }
271+     Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
272+     Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
275273
276274    //  Algorithm borrowed from this excellent code:
277275    //  https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
291289    //  Constant used to make the rounding bias.
292290    Value c7FFF = createConst (op.getLoc (), i32Ty, 0x7fff , rewriter);
293291    //  Constant used to generate a quiet NaN.
294-     Value c7FC0_i16  = createConst (op.getLoc (), i16Ty, 0x7fc0 , rewriter);
292+     Value c7FC0I16  = createConst (op.getLoc (), i16Ty, 0x7fc0 , rewriter);
295293    //  Small constants used to address bits.
296294    Value c16 = createConst (op.getLoc (), i32Ty, 16 , rewriter);
297295    Value c1 = createConst (op.getLoc (), i32Ty, 1 , rewriter);
@@ -313,18 +311,104 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
313311    //  Now that the rounding-bias has been added, truncating the low bits
314312    //  yields the correctly rounded result.
315313    Value biasedAndShifted = b.create <arith::ShRUIOp>(biased, c16);
316-     Value normalCaseResult_i16  =
314+     Value normalCaseResultI16  =
317315        b.create <arith::TruncIOp>(i16Ty, biasedAndShifted);
318316    //  Select either the above-computed result, or a quiet NaN constant
319317    //  if the input was NaN.
320318    Value select =
321-         b.create <arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16 );
319+         b.create <arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16 );
322320    Value result = b.create <arith::BitcastOp>(resultTy, select);
323321    rewriter.replaceOp (op, result);
324322    return  success ();
325323  }
326324};
327325
326+ struct  F8E8M0ExtFOpConverter  : public  OpRewritePattern <arith::ExtFOp> {
327+   using  OpRewritePattern::OpRewritePattern;
328+   LogicalResult matchAndRewrite (arith::ExtFOp op,
329+                                 PatternRewriter &rewriter) const  final  {
330+     ImplicitLocOpBuilder b (op.getLoc (), rewriter);
331+     Value operand = op.getOperand ();
332+     Type operandTy = operand.getType ();
333+     Type resultTy = op.getType ();
334+     Type operandETy = getElementTypeOrSelf (operandTy);
335+     Type resultETy = getElementTypeOrSelf (resultTy);
336+ 
337+     if  (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
338+       return  rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU"  );
339+     }
340+ 
341+     Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
342+     Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
343+     Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
344+ 
345+     Value bitcast = b.create <arith::BitcastOp>(i8Ty, operand);
346+     //  create constants for NaNs
347+     Value cF8NaN = createConst (op.getLoc (), i8Ty, 0xff , rewriter);
348+     Value cF32NaN = createConst (op.getLoc (), i32Ty, 0xffffffff , rewriter);
349+     Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
350+ 
351+     Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
352+     Value f32Bits = b.create <arith::ShLIOp>(exti, cF32MantissaWidth);
353+ 
354+     Value isNan =
355+         b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
356+     //  select for NaNs
357+     f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
358+     Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
359+     if  (resultETy.getIntOrFloatBitWidth () < 32 ) {
360+       result = b.create <arith::TruncFOp>(resultTy, result);
361+     } else  if  (resultETy.getIntOrFloatBitWidth () > 32 ) {
362+       result = b.create <arith::ExtFOp>(resultTy, result);
363+     }
364+     rewriter.replaceOp (op, result);
365+     return  success ();
366+   }
367+ };
368+ 
369+ /* 
370+ TruncF to F8E8M0 is expected to extract exponent bits out of F32 type 
371+ Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type, 
372+ they all map to NaN in F8E8M0 Type. 
373+ */ 
374+ struct  F8E8M0TruncFOpConverter  : public  OpRewritePattern <arith::TruncFOp> {
375+   using  OpRewritePattern::OpRewritePattern;
376+   LogicalResult matchAndRewrite (arith::TruncFOp op,
377+                                 PatternRewriter &rewriter) const  final  {
378+     ImplicitLocOpBuilder b (op.getLoc (), rewriter);
379+     Value operand = op.getOperand ();
380+     Type operandTy = operand.getType ();
381+     Type operandETy = getElementTypeOrSelf (operandTy);
382+     Type resultTy = op.getType ();
383+     Type resultETy = getElementTypeOrSelf (resultTy);
384+     if  (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
385+       return  rewriter.notifyMatchFailure (op, " not a truncf to f8E8M0FNU"  );
386+     }
387+ 
388+     if  (op.getRoundingmodeAttr ()) {
389+       return  rewriter.notifyMatchFailure (
390+           op, " only applicable to default rounding mode."  );
391+     }
392+ 
393+     Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
394+     Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
395+     Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
396+ 
397+     if  (operandETy.getIntOrFloatBitWidth () < 32 ) {
398+       operand = b.create <arith::ExtFOp>(f32Ty, operand);
399+     } else  if  (operandETy.getIntOrFloatBitWidth () > 32 ) {
400+       operand = b.create <arith::TruncFOp>(f32Ty, operand);
401+     }
402+     Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
403+     Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
404+     Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
405+     Value exp8Bits = b.create <arith::TruncIOp>(i8Ty, f32SignExp);
406+     Value result = b.create <arith::BitcastOp>(resultTy, exp8Bits);
407+     rewriter.replaceOp (op, result);
408+     return  success ();
409+   }
410+ };
411+ 
328412struct  ArithExpandOpsPass 
329413    : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
330414  using  ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -353,20 +437,34 @@ struct ArithExpandOpsPass
353437
354438    if  (includeBf16) {
355439      arith::populateExpandBFloat16Patterns (patterns);
356-       target.addDynamicallyLegalOp <arith::ExtFOp>(
357-         [](arith::ExtFOp op) {
358-           Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
359-           Type outETy = getElementTypeOrSelf (op.getType ());
360-           return  !(inETy.isBF16 () && outETy.isF32 ());
361-         });
362- 
363-       target.addDynamicallyLegalOp <arith::TruncFOp>(
364-         [](arith::TruncFOp op)  {
365-           Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
366-           Type outETy = getElementTypeOrSelf (op.getType ());
367-           return  !(inETy.isF32 () && outETy.isBF16 ());
368-         });
369440    }
441+     if  (includeF8E8M0) {
442+       arith::populateExpandF8E8M0Patterns (patterns);
443+     }
444+ 
445+     target.addDynamicallyLegalOp <arith::ExtFOp>(
446+       [=](arith::ExtFOp op) {
447+         Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
448+         Type outETy = getElementTypeOrSelf (op.getType ());
449+         bool  legalTypes = true ;
450+         if  (includeBf16) 
451+           legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
452+         if  (includeF8E8M0)
453+           legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
454+         return  legalTypes;
455+       });
456+ 
457+     target.addDynamicallyLegalOp <arith::TruncFOp>(
458+       [=](arith::TruncFOp op)  {
459+         Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
460+         Type outETy = getElementTypeOrSelf (op.getType ());
461+         bool  legalTypes = true ;
462+         if  (includeBf16) 
463+           legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
464+         if  (includeF8E8M0) 
465+           legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy)); 
466+         return  legalTypes;
467+       });
370468
371469    //  clang-format on
372470    if  (failed (applyPartialConversion (getOperation (), target,
@@ -389,6 +487,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
389487      patterns.getContext ());
390488}
391489
490+ void  mlir::arith::populateExpandF8E8M0Patterns (RewritePatternSet &patterns) {
491+   patterns.add <F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
492+       patterns.getContext ());
493+ }
494+ 
392495void  mlir::arith::populateArithExpandOpsPatterns (RewritePatternSet &patterns) {
393496  populateCeilFloorDivExpandOpsPatterns (patterns);
394497  //  clang-format off
0 commit comments