@@ -102,20 +102,23 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
102102 return rewriter.replaceOp (op, result);
103103 }
104104 int64_t numElements = inType.getNumElements ();
105+
105106 Value zero = rewriter.create <arith::ConstantOp>(
106107 loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
108+ VectorType outType = cast<VectorType>(op.getOut ().getType ());
109+
107110 if (inType.getShape ().empty ()) {
111+ Value zerodSplat =
112+ rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
108113 Value scalarIn =
109114 rewriter.create <vector::ExtractOp>(loc, in, ArrayRef<int64_t >{});
110- // Recurse to send the 0-D vector case to the 1-D vector case
111115 Value scalarExt =
112116 rewriter.create <arith::ExtFOp>(loc, outElemType, scalarIn);
113- Value result = rewriter.create <vector::InsertOp>(loc, scalarExt, zero ,
117+ Value result = rewriter.create <vector::InsertOp>(loc, scalarExt, zerodSplat ,
114118 ArrayRef<int64_t >{});
115119 return rewriter.replaceOp (op, result);
116120 }
117121
118- VectorType outType = cast<VectorType>(op.getOut ().getType ());
119122 VectorType flatTy = VectorType::get (SmallVector<int64_t >{numElements},
120123 outType.getElementType ());
121124 Value result = rewriter.createOrFold <vector::SplatOp>(loc, flatTy, zero);
0 commit comments