Skip to content

Commit 1c0547f

Browse files
authored
Support all static tensor shapes in EvalSlice folder (#2557)
This patch enhances the EvalSlice constant folder to support all operator configurations on static tensor shapes, including 0-rank scalars.
1 parent 75bac79 commit 1c0547f

File tree

2 files changed

+81
-42
lines changed

2 files changed

+81
-42
lines changed

stablehlo/tests/transforms/stablehlo_refine_shapes.mlir

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -433,17 +433,25 @@ func.func @eval_sign() -> tensor<3xi64> {
433433
// -----
434434

435435
// CHECK-LABEL: func @eval_slice
436-
func.func @eval_slice() -> tensor<2xi64> {
436+
func.func @eval_slice() -> (tensor<2xi64>, tensor<1x2x1xi64>) {
437437
// CHECK-NOT: stablehlo.slice
438-
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi64>
439-
// CHECK: return [[RESULT]]
438+
// CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi64>
439+
// CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<{{\[\[}}[15], [19]]]> : tensor<1x2x1xi64>
440+
// CHECK: return [[RESULT1]], [[RESULT2]]
440441
%0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
441442
%1 = "stablehlo.slice"(%0) {
442443
start_indices = array<i64: 0>,
443444
limit_indices = array<i64: 2>,
444445
strides = array<i64: 1>
445446
} : (tensor<4xi64>) -> tensor<2xi64>
446-
func.return %1 : tensor<2xi64>
447+
%2 = stablehlo.constant dense<[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]],
448+
[[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]> : tensor<2x3x4xi64>
449+
%3 = "stablehlo.slice"(%2) {
450+
start_indices = array<i64: 0, 1, 1>,
451+
limit_indices = array<i64: 2, 3, 3>,
452+
strides = array<i64: 3, 1, 2>
453+
} : (tensor<2x3x4xi64>) -> tensor<1x2x1xi64>
454+
func.return %1, %3 : tensor<2xi64>, tensor<1x2x1xi64>
447455
}
448456

449457
// -----
@@ -496,18 +504,32 @@ func.func @eval_slice_unit_prefix() -> (tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>
496504

497505
// -----
498506

499-
// CHECK-LABEL: func @eval_slice_non_unit_prefix
500-
func.func @eval_slice_non_unit_prefix() -> tensor<1x2x1xi64> {
501-
// CHECK: stablehlo.constant {{.*}} : tensor<1x2x2xi64>
502-
// CHECK: [[RESULT:%.*]] = stablehlo.slice{{.*}}
507+
// CHECK-LABEL: func @eval_slice_zerodim
508+
func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> {
509+
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<> : tensor<0x2x1xi64>
503510
// CHECK: return [[RESULT]]
504511
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
505512
%1 = "stablehlo.slice"(%0) {
506-
start_indices = array<i64: 0, 0, 1>,
513+
start_indices = array<i64: 1, 0, 1>,
507514
limit_indices = array<i64: 1, 2, 2>,
508515
strides = array<i64: 1, 1, 1>
509-
} : (tensor<1x2x2xi64>) -> tensor<1x2x1xi64>
510-
func.return %1 : tensor<1x2x1xi64>
516+
} : (tensor<1x2x2xi64>) -> tensor<0x2x1xi64>
517+
func.return %1 : tensor<0x2x1xi64>
518+
}
519+
520+
// -----
521+
522+
// CHECK-LABEL: func @eval_slice_zerorank
523+
func.func @eval_slice_zerorank() -> tensor<f32> {
524+
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor<f32>
525+
// CHECK: return [[RESULT]]
526+
%0 = stablehlo.constant dense<33.0> : tensor<f32>
527+
%1 = "stablehlo.slice"(%0) {
528+
start_indices = array<i64>,
529+
limit_indices = array<i64>,
530+
strides = array<i64>
531+
} : (tensor<f32>) -> tensor<f32>
532+
func.return %1 : tensor<f32>
511533
}
512534

513535
// -----

stablehlo/transforms/StablehloAggressiveFolder.cpp

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,41 @@ struct EvalSignOpPattern : public OpRewritePattern<SignOp> {
521521
}
522522
};
523523

524+
template <typename RangeType>
525+
DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) {
526+
using ElementType = std::decay_t<decltype(*std::begin(data))>;
527+
528+
RankedTensorType operandType = op.getOperand().getType();
529+
RankedTensorType resultType = op.getResult().getType();
530+
531+
const auto dimOffsets = computeStrides(operandType.getShape());
532+
auto startIndices = op.getStartIndices();
533+
auto limitIndices = op.getLimitIndices();
534+
auto strides = op.getStrides();
535+
536+
const SmallVector<int64_t> startIndex(startIndices);
537+
const SmallVector<int64_t> endIndex(limitIndices);
538+
539+
SmallVector<ElementType> result;
540+
result.reserve(resultType.getNumElements());
541+
542+
SmallVector<int64_t> srcIndex(startIndex);
543+
for (int64_t i = 0; i < resultType.getNumElements(); ++i) {
544+
auto srcLinearIndex = linearize(srcIndex, dimOffsets);
545+
result.push_back(data[srcLinearIndex]);
546+
for (int64_t dim = srcIndex.size() - 1; dim >= 0; --dim) {
547+
srcIndex[dim] += strides[dim];
548+
if (srcIndex[dim] >= endIndex[dim])
549+
srcIndex[dim] = startIndex[dim];
550+
else
551+
break;
552+
}
553+
}
554+
555+
return DenseElementsAttr::get(op.getResult().getType(),
556+
ArrayRef<ElementType>(result));
557+
}
558+
524559
struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
525560
using OpRewritePattern::OpRewritePattern;
526561
LogicalResult matchAndRewrite(SliceOp op,
@@ -529,45 +564,27 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
529564
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
530565
return failure();
531566

532-
if (resultType.getRank() < 1)
533-
return rewriter.notifyMatchFailure(
534-
op, "expected non-0 ranked tensor result type");
535-
536-
auto operand = cast<TypedValue<RankedTensorType>>(op.getOperand());
567+
auto operand = op.getOperand();
537568
RankedTensorType operandType = operand.getType();
538569
if (!operandType.hasStaticShape())
539570
return rewriter.notifyMatchFailure(
540571
op, "expected operand with static ranked tensor type");
541572

542-
// A ranked tensor type with unit dimension prefix of R-1 size is physically
543-
// compatible with 1-dimensional type.
544-
if (!llvm::all_of(resultType.getShape().drop_back(),
545-
[](int64_t s) { return s == 1; }))
573+
ElementsAttr els;
574+
if (!matchPattern(operand, m_Constant(&els)))
546575
return rewriter.notifyMatchFailure(
547-
op, "expected 1-dimensional compatible result type");
548-
549-
SmallVector<APSInt> operandData;
550-
if (failed(hlo::matchInts(operand, operandData)))
551-
return rewriter.notifyMatchFailure(op, "expected constant operand");
552-
553-
const auto dimOffsets = computeSuffixProduct(operandType.getShape());
554-
auto startIndices = op.getStartIndices();
555-
auto limitIndices = op.getLimitIndices();
556-
auto strides = op.getStrides();
557-
558-
int64_t start = 0;
559-
for (size_t i = 0; i < startIndices.size(); ++i)
560-
start += startIndices[i] * dimOffsets[i];
576+
op, "expected constant integer or float operand");
561577

562-
auto slicedDim = operandType.getRank() - 1;
563-
int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim];
564-
int64_t stride = strides[slicedDim];
565-
SmallVector<APSInt> result;
566-
for (auto i = start; i < limit; i += stride)
567-
result.push_back(operandData[i]);
578+
DenseElementsAttr resAttr;
579+
if (auto data = els.tryGetValues<APInt>())
580+
resAttr = sliceType(op, *data);
581+
else if (auto data = els.tryGetValues<APFloat>())
582+
resAttr = sliceType(op, *data);
583+
else
584+
return rewriter.notifyMatchFailure(op.getLoc(),
585+
"unsupported element type");
568586

569-
rewriter.replaceOpWithNewOp<ConstantOp>(op,
570-
getTensorAttr(resultType, result));
587+
rewriter.replaceOpWithNewOp<ConstantOp>(op, resAttr);
571588
return success();
572589
}
573590
};

0 commit comments

Comments
 (0)