@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
3535 return rewriter.create <arith::ConstantOp>(loc, attr);
3636}
3737
38+ // / Creates shapedType using shape from cloneFrom and base type from cloneTo
39+ static Type cloneToShapedType (Type cloneFrom, Type cloneTo) {
40+ if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
41+ return shapedTy.clone (cloneTo);
42+ }
43+ return cloneTo;
44+ }
45+
3846namespace {
3947
4048// / Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
225233 return rewriter.notifyMatchFailure (op, " not a ext of bf16 to f32." );
226234 }
227235
228- Type i16Ty = b.getI16Type ();
229- Type i32Ty = b.getI32Type ();
230- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
231- i16Ty = shapedTy.clone (i16Ty);
232- i32Ty = shapedTy.clone (i32Ty);
233- }
236+ Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
237+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
234238
235239 Value bitcast = b.create <arith::BitcastOp>(i16Ty, operand);
236240 Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
264268 op, " only applicable to default rounding mode." );
265269 }
266270
267- Type i16Ty = b.getI16Type ();
268- Type i32Ty = b.getI32Type ();
269- Type f32Ty = b.getF32Type ();
270- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
271- i16Ty = shapedTy.clone (i16Ty);
272- i32Ty = shapedTy.clone (i32Ty);
273- f32Ty = shapedTy.clone (f32Ty);
274- }
271+ Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
272+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
275273
276274 // Algorithm borrowed from this excellent code:
277275 // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -340,14 +338,9 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
340338 return rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU" );
341339 }
342340
343- Type i8Ty = b.getI8Type ();
344- Type i32Ty = b.getI32Type ();
345- Type f32Ty = b.getF32Type ();
346- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
347- i8Ty = shapedTy.clone (i8Ty);
348- i32Ty = shapedTy.clone (i32Ty);
349- f32Ty = shapedTy.clone (f32Ty);
350- }
341+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
342+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
343+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
351344
352345 Value bitcast = b.create <arith::BitcastOp>(i8Ty, operand);
353346 // create constants for NaNs
@@ -397,14 +390,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
397390 op, " only applicable to default rounding mode." );
398391 }
399392
400- Type i8Ty = b.getI8Type ();
401- Type i32Ty = b.getI32Type ();
402- Type f32Ty = b.getF32Type ();
403- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
404- i8Ty = shapedTy.clone (i8Ty);
405- i32Ty = shapedTy.clone (i32Ty);
406- f32Ty = shapedTy.clone (f32Ty);
407- }
393+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
394+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
395+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
396+
408397 if (operandETy.getIntOrFloatBitWidth () < 32 ) {
409398 operand = b.create <arith::ExtFOp>(f32Ty, operand);
410399 } else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
0 commit comments