@@ -256,17 +256,25 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> {
256256 // Produce a StableHLO equivalent of this shape::ShapeOfOp.
257257 // This is a very laborious representation because StableHLO is currently
258258 // lacking convenient tools to express this.
259- SmallVector<Value> sizesI32x1;
260- for (auto i = 0 ; i < operandType.getRank (); ++i) {
261- auto sizeI32 =
262- rewriter.create <GetDimensionSizeOp>(op.getLoc (), op.getArg (), i);
263- auto sizeI32x1 = rewriter.create <ReshapeOp>(
264- op.getLoc (), RankedTensorType::get ({1 }, rewriter.getI32Type ()),
265- sizeI32);
266- sizesI32x1.push_back (sizeI32x1);
259+ Value shapeI32;
260+ if (operandType.getRank () > 0 ) {
261+ SmallVector<Value> sizesI32x1;
262+ for (auto i = 0 ; i < operandType.getRank (); ++i) {
263+ auto sizeI32 =
264+ rewriter.create <GetDimensionSizeOp>(op.getLoc (), op.getArg (), i);
265+ auto sizeI32x1 = rewriter.create <ReshapeOp>(
266+ op.getLoc (), RankedTensorType::get ({1 }, rewriter.getI32Type ()),
267+ sizeI32);
268+ sizesI32x1.push_back (sizeI32x1);
269+ }
270+ shapeI32 = rewriter.create <ConcatenateOp>(op.getLoc (), sizesI32x1,
271+ /* dimension=*/ 0 );
272+ } else {
273+ shapeI32 = rewriter.create <ConstantOp>(
274+ op.getLoc (), DenseElementsAttr::get (
275+ RankedTensorType::get ({0 }, rewriter.getI32Type ()),
276+ ArrayRef<Attribute>()));
267277 }
268- auto shapeI32 = rewriter.create <ConcatenateOp>(op.getLoc (), sizesI32x1,
269- /* dimension=*/ 0 );
270278
271279 // Cast result from tensor<Nxi32> to tensor<Nxindex>.
272280 // This will error out if the result is !shape.shape.
0 commit comments