Skip to content

Commit ef07ca9

Browse files
authored
Convert StableHLO reshape to TOSA reshape (#2861)
[tosa.reshape](https://mlir.llvm.org/docs/Dialects/TOSA/#tosareshape-mlirtosareshapeop) accepts shape as a second argument. Conversion of the StableHLO reshape requires insertion of the [tosa.const_shape](https://mlir.llvm.org/docs/Dialects/TOSA/#tosaconst_shape-mlirtosaconstshapeop) op as an argument for the tosa.reshape.
1 parent e1d2f65 commit ef07ca9

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

stablehlo/conversions/tosa/tests/unary.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,11 @@ func.func @while(%arg0: tensor<i32>) -> tensor<i32> {
155155
}) : (tensor<i32>) -> (tensor<i32>)
156156
return %0 : tensor<i32>
157157
}
158+
159+
// CHECK-LABEL: @reshape
160+
func.func @reshape(%arg0 : tensor<2x3xf32>) -> tensor<6xf32> {
161+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
162+
// CHECK: tosa.reshape %arg0, %[[VAR0]]
163+
%0 = "stablehlo.reshape"(%arg0) : (tensor<2x3xf32>) -> tensor<6xf32>
164+
return %0 : tensor<6xf32>
165+
}

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,38 @@ struct ConvertStablehloWhileOp : public OpRewritePattern<stablehlo::WhileOp> {
491491
}
492492
};
493493

494+
struct ConvertStablehloReshapeOp
495+
: public OpRewritePattern<mlir::stablehlo::ReshapeOp> {
496+
using OpRewritePattern<mlir::stablehlo::ReshapeOp>::OpRewritePattern;
497+
498+
LogicalResult matchAndRewrite(stablehlo::ReshapeOp op,
499+
PatternRewriter& rewriter) const override {
500+
auto resultType = op.getResult().getType();
501+
if (!resultType.hasStaticShape()) {
502+
return rewriter.notifyMatchFailure(op, "result tensor must be static");
503+
}
504+
505+
auto resultShape = resultType.getShape();
506+
SmallVector<int64_t, 8> dimensions(resultShape.begin(), resultShape.end());
507+
508+
RankedTensorType shapeTensorType = RankedTensorType::get(
509+
{static_cast<int64_t>(dimensions.size())}, rewriter.getIndexType());
510+
511+
auto denseAttr = DenseIntElementsAttr::get(shapeTensorType, dimensions);
512+
auto shapeType =
513+
tosa::shapeType::get(rewriter.getContext(), dimensions.size());
514+
515+
auto constShapeOp =
516+
rewriter.create<tosa::ConstShapeOp>(op.getLoc(), shapeType, denseAttr);
517+
518+
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
519+
op.getLoc(), resultType, op.getOperand(), constShapeOp);
520+
521+
rewriter.replaceOp(op, reshapeOp.getResult());
522+
return success();
523+
}
524+
};
525+
494526
LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
495527
RewritePatternSet patternList(ctx);
496528
populateGeneratedPDLLPatterns(patternList);
@@ -509,6 +541,9 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
509541
patternList.addWithLabel<ConvertStablehloTransposeOp>({"StablehloTranspose"},
510542
ctx);
511543
patternList.addWithLabel<ConvertStablehloWhileOp>({"StablehloWhile"}, ctx);
544+
patternList.addWithLabel<ConvertStablehloReshapeOp>({"StablehloReshape"},
545+
ctx);
546+
512547
patterns = std::move(patternList);
513548
return success();
514549
}

0 commit comments

Comments
 (0)