Skip to content

Commit 3442dbe

Browse files
authored
Fix failed reify for ReverseOp (#2109)
The `ReverseOp` did not implement a `reifyReturnTypeShapes`. This resulted in a failed assertion in the `getEmptyTensorFor` helper function that gets used during the `stablehlo-legalize-to-linalg` pass when the op has dynamic input/output. This patch implements this function as well as adds a test for the newly supported case (input can have dynamic dims as long as they don't correspond to the `dims` attribute of the op) Various other quality of life changes, such as having match failure notifications have been added.
1 parent ec4ec78 commit 3442dbe

File tree

6 files changed

+45
-6
lines changed

6 files changed

+45
-6
lines changed

stablehlo/conversions/linalg/tests/miscellaneous.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,21 @@ func.func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
10721072

10731073
// -----
10741074

1075+
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
1076+
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
1077+
// CHECK: func @reverse_dynamic
1078+
func.func @reverse_dynamic(%input: tensor<?x3xf32>) -> tensor<?x3xf32> {
1079+
%result = "stablehlo.reverse"(%input) {
1080+
dimensions = array<i64: 1>, someattr
1081+
} : (tensor<?x3xf32>) -> tensor<?x3xf32>
1082+
func.return %result : tensor<?x3xf32>
1083+
}
1084+
// CHECK: linalg.generic
1085+
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
1086+
// CHECK-SAME: {someattr}
1087+
1088+
// -----
1089+
10751090
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
10761091
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
10771092
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
7777
// Ask the op for its output shape.
7878
auto shapeSource = cast<InferShapedTypeOpInterface>(op);
7979
SmallVector<Value, 1> reifiedShapes;
80-
(void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
80+
assert(succeeded(
81+
shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes)) &&
82+
"could not reify");
8183
assert(reifiedShapes.size() == 1 && "Expected one reified result");
8284
// Construct sizes for the required dimensions.
8385
for (const auto &en : llvm::enumerate(resultType.getShape())) {

stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,18 +354,20 @@ struct DataMovementOpConverter : OpConversionPattern<OpTy> {
354354
LogicalResult matchAndRewrite(
355355
OpTy op, typename OpTy::Adaptor adaptor,
356356
ConversionPatternRewriter &rewriter) const final {
357-
if (failed(verifyHloOpBufferOrTensorSemantics(op))) return failure();
357+
if (failed(verifyHloOpBufferOrTensorSemantics(op)))
358+
return rewriter.notifyMatchFailure(
359+
op, "failed to verify hlo buffer or tensor semantics");
358360

359361
ShapedType resultType = getHloOpResultType(op);
360362
resultType =
361363
this->getTypeConverter()->template convertType<ShapedType>(resultType);
362-
if (!resultType) {
364+
if (!resultType)
363365
return rewriter.notifyMatchFailure(op, "type conversion failed");
364-
}
365366

366367
SmallVector<AffineMap, 2> indexingMaps =
367368
Derived::getIndexingMaps(op, &rewriter);
368-
if (indexingMaps.empty()) return failure();
369+
if (indexingMaps.empty())
370+
return rewriter.notifyMatchFailure(op, "could not derive indexing maps");
369371

370372
int64_t nloops = resultType.getRank();
371373
Location loc = op.getLoc();

stablehlo/dialect/StablehloOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,6 +1605,13 @@ LogicalResult ReverseOp::inferReturnTypeComponents(
16051605
inferredReturnShapes);
16061606
}
16071607

1608+
LogicalResult ReverseOp::reifyReturnTypeShapes(
1609+
OpBuilder& builder, ValueRange operands,
1610+
SmallVectorImpl<Value>& reifiedReturnShapes) {
1611+
return ::mlir::hlo::deriveShapeFromOperand(
1612+
&builder, getOperation(), operands.front(), &reifiedReturnShapes);
1613+
}
1614+
16081615
//===----------------------------------------------------------------------===//
16091616
// RngBitGeneratorOp
16101617
//===----------------------------------------------------------------------===//

stablehlo/dialect/StablehloOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2730,7 +2730,7 @@ def StableHLO_SortOp : StableHLO_Op<"sort",
27302730
let hasVerifier = 1;
27312731
}
27322732

2733-
def StableHLO_ReverseOp: StableHLO_Op<"reverse",
2733+
def StableHLO_ReverseOp: StableHLO_ShapedInterfaceOp<"reverse",
27342734
[Pure, HLO_CompatibleOperandsAndResultType /*reverse_c1*/]> {
27352735
let summary = "Reverse operation";
27362736
let description = [{

stablehlo/tests/infer_stablehlo.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,3 +1826,16 @@ func.func @select(%pred : tensor<i1>,
18261826
%1 = "hlo_test_infer.get_return_types"(%0) : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xindex>
18271827
func.return %1 : tensor<?x?x?x?xindex>
18281828
}
1829+
1830+
// -----
1831+
1832+
// CHECK-LABEL: @reverse
1833+
// CHECK-SAME: %[[A:.*]]: tensor<?x?x?x?xf32>
1834+
func.func @reverse(%a : tensor<?x?x?x?xf32>) -> tensor<4xindex> {
1835+
%0 = "stablehlo.reverse"(%a) {
1836+
dimensions = array<i64: 1, 3>, someattr
1837+
} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
1838+
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<?x?x?x?xf32> -> tensor<4xindex>
1839+
%1 = "hlo_test_infer.reify_return_type_shapes"(%0) : (tensor<?x?x?x?xf32>) -> tensor<4xindex>
1840+
func.return %1 : tensor<4xindex>
1841+
}

0 commit comments

Comments
 (0)