Commit ec4ec78

Skip conversion of shape.shape_of with 0-ranked tensor operand (#2107)
Currently `--shape-legalize-to-stablehlo` fails on the following code:

```mlir
func.func @test1() -> tensor<0xindex> {
  %0 = arith.constant dense<0> : tensor<i32>
  %1 = shape.shape_of %0 : tensor<i32> -> tensor<0xindex>
  func.return %1 : tensor<0xindex>
}
```

at `stablehlo/dialect/TypeInference.cpp:1700`:

```cpp
// concatenate_c5
auto elementType = inputTypes[0].cast<ShapedType>().getElementType();
```

because `inputTypes.size()` is zero: for a rank-0 operand the lowering collects no per-dimension sizes, so the `stablehlo.concatenate` it builds has no inputs.

I have checked how the pass works on a non-0-ranked tensor type:

```mlir
func.func @test2() -> tensor<2xindex> {
  %1 = arith.constant dense<0> : tensor<2x128xi32>
  %3 = shape.shape_of %1 : tensor<2x128xi32> -> tensor<2xindex>
  func.return %3 : tensor<2xindex>
}
```

produces:

```mlir
func.func @test2() -> tensor<2xindex> {
  %cst = arith.constant dense<0> : tensor<2x128xi32>
  %0 = stablehlo.get_dimension_size %cst, dim = 0 : (tensor<2x128xi32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.get_dimension_size %cst, dim = 1 : (tensor<2x128xi32>) -> tensor<i32>
  %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
  %4 = stablehlo.concatenate %1, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
  %5 = builtin.unrealized_conversion_cast %4 : tensor<2xi32> to tensor<2xindex>
  return %5 : tensor<2xindex>
}
```

Instead of simply bailing out, I suggest considering an alternative approach: one option could be generating a constant tensor with a zero-sized dimension:

```mlir
func.func @test1() -> tensor<0xindex> {
  %cst = arith.constant dense<0> : tensor<i32>
  %0 = stablehlo.constant dense<> : tensor<0xi32>
  %1 = builtin.unrealized_conversion_cast %0 : tensor<0xi32> to tensor<0xindex>
  return %1 : tensor<0xindex>
}
```

but I am not entirely certain about the semantic equivalence.
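As background for the suggested `dense<> : tensor<0xi32>` constant, here is a minimal standalone sketch of how such an empty attribute is built with the MLIR API. This is an illustration added for reference, not code from this commit; it assumes only standard MLIR headers and the public `DenseElementsAttr::get` overload taking an attribute list:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);

  // tensor<0xi32>: a ranked tensor with one dimension of size zero,
  // i.e. it holds no elements at all.
  auto emptyType = mlir::RankedTensorType::get({0}, builder.getI32Type());

  // With zero elements to supply, an empty attribute list is valid and
  // yields the attribute printed as dense<> : tensor<0xi32>.
  auto emptyAttr = mlir::DenseElementsAttr::get(
      emptyType, llvm::ArrayRef<mlir::Attribute>());
  emptyAttr.dump();
  return 0;
}
```

Printing the attribute should show `dense<> : tensor<0xi32>`, matching the constant expected by the CHECK lines in the test below.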
Parent commit: 9c1bccf

File tree: 2 files changed (+30, -10 lines)
stablehlo/tests/shape_legalize_to_stablehlo.mlir

Lines changed: 12 additions & 0 deletions
```diff
@@ -387,3 +387,15 @@ func.func @tensor_extract_dynamic(%arg0: tensor<?x3xindex>) -> index {
   %0 = tensor.extract %arg0[%c1, %c2] : tensor<?x3xindex>
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @shape_of_zero_ranked_tensor
+func.func @shape_of_zero_ranked_tensor(%arg0: tensor<?x3xindex>) -> tensor<0xindex> {
+  // CHECK: %[[CONST:.*]] = stablehlo.constant dense<> : tensor<0xi32>
+  // CHECK-NEXT: %[[RES_DIM0_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONST]] : tensor<0xi32> to tensor<0xindex>
+  // CHECK-NEXT: return %[[RES_DIM0_INDEX]] : tensor<0xindex>
+  %0 = arith.constant dense<0> : tensor<i32>
+  %1 = shape.shape_of %0 : tensor<i32> -> tensor<0xindex>
+  func.return %1 : tensor<0xindex>
+}
```

stablehlo/transforms/ShapeLegalizeToStablehlo.cpp

Lines changed: 18 additions & 10 deletions
```diff
@@ -256,17 +256,25 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> {
     // Produce a StableHLO equivalent of this shape::ShapeOfOp.
     // This is a very laborious representation because StableHLO is currently
     // lacking convenient tools to express this.
-    SmallVector<Value> sizesI32x1;
-    for (auto i = 0; i < operandType.getRank(); ++i) {
-      auto sizeI32 =
-          rewriter.create<GetDimensionSizeOp>(op.getLoc(), op.getArg(), i);
-      auto sizeI32x1 = rewriter.create<ReshapeOp>(
-          op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
-          sizeI32);
-      sizesI32x1.push_back(sizeI32x1);
+    Value shapeI32;
+    if (operandType.getRank() > 0) {
+      SmallVector<Value> sizesI32x1;
+      for (auto i = 0; i < operandType.getRank(); ++i) {
+        auto sizeI32 =
+            rewriter.create<GetDimensionSizeOp>(op.getLoc(), op.getArg(), i);
+        auto sizeI32x1 = rewriter.create<ReshapeOp>(
+            op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
+            sizeI32);
+        sizesI32x1.push_back(sizeI32x1);
+      }
+      shapeI32 = rewriter.create<ConcatenateOp>(op.getLoc(), sizesI32x1,
+                                                /*dimension=*/0);
+    } else {
+      shapeI32 = rewriter.create<ConstantOp>(
+          op.getLoc(), DenseElementsAttr::get(
+                           RankedTensorType::get({0}, rewriter.getI32Type()),
+                           ArrayRef<Attribute>()));
     }
-    auto shapeI32 = rewriter.create<ConcatenateOp>(op.getLoc(), sizesI32x1,
-                                                   /*dimension=*/0);
 
     // Cast result from tensor<Nxi32> to tensor<Nxindex>.
     // This will error out if the result is !shape.shape.
```
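A possible simplification, offered only as a hypothetical sketch against the upstream MLIR `Builder` API (this commit does not use it): `Builder::getI32TensorAttr` derives the result tensor type from the number of values it is given, so the rank-0 branch could likely avoid spelling out the type:

```cpp
// Hypothetical drop-in for the else-branch above (assumes the same
// rewriter/op context as the pattern): an empty value list makes
// getI32TensorAttr build dense<> : tensor<0xi32> automatically.
shapeI32 = rewriter.create<ConstantOp>(op.getLoc(),
                                       rewriter.getI32TensorAttr({}));
```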
