Skip to content

Commit 969b9ad

Browse files
authored
Add support for tensor types with a unit dimension prefix in EvalSlice (#2090)
Edit: Sure, sorry for the missing description. Currently our small research team is experimenting with a BERT model represented in a number of *-hlo dialects and wants to simplify it in terms of the variety of operators. This PR fixes an issue we stumbled upon: ``` %42 = stablehlo.constant dense<"..."> : tensor<1x512xi64> ... %66 = stablehlo.slice %42 [0:1, 0:128] : (tensor<1x512xi64>) -> tensor<1x128xi64> ``` The previous implementation of EvalSlice can't handle such a case -- a tensor type prefixed with unit dimension(s), i.e. 1x128. This PR adds support for the above case and can slice from any position, e.g. ``` %0 = stablehlo.constant dense<[[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]> : tensor<1x3x2x2xi64> %1 = "stablehlo.slice"(%0) { start_indices = array<i64: 0, 1, 1, 0>, limit_indices = array<i64: 1, 2, 2, 2>, strides = array<i64: 1, 1, 1, 1> } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64> ``` is folded to ``` %1 = stablehlo.constant dense<[[[[7, 8]]]]> : tensor<1x1x1x2xi64> ```
1 parent 3442dbe commit 969b9ad

File tree

4 files changed

+100
-11
lines changed

4 files changed

+100
-11
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,7 @@ cc_library(
963963
"@llvm-project//mlir:AsmParser",
964964
"@llvm-project//mlir:CommonFolders",
965965
"@llvm-project//mlir:ComplexDialect",
966+
"@llvm-project//mlir:DialectUtils",
966967
"@llvm-project//mlir:FuncDialect",
967968
"@llvm-project//mlir:FunctionInterfaces",
968969
"@llvm-project//mlir:IR",

stablehlo/tests/stablehlo_refine_shapes.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,70 @@ func.func @eval_slice() -> tensor<2xi64> {
424424

425425
// -----
426426

427+
// CHECK-LABEL: func @eval_slice_wild_stride
428+
func.func @eval_slice_wild_stride() -> tensor<1x1x1xi64> {
429+
// CHECK-NOT: stablehlo.slice
430+
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<2> : tensor<1x1x1xi64>
431+
// CHECK: return [[RESULT]]
432+
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
433+
%1 = "stablehlo.slice"(%0) {
434+
start_indices = array<i64: 0, 0, 1>,
435+
limit_indices = array<i64: 1, 1, 2>,
436+
strides = array<i64: 99, 42, 1>
437+
} : (tensor<1x2x2xi64>) -> tensor<1x1x1xi64>
438+
func.return %1 : tensor<1x1x1xi64>
439+
}
440+
441+
// -----
442+
443+
// CHECK-LABEL: func @eval_slice_unit_prefix
444+
func.func @eval_slice_unit_prefix() -> (tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>) {
445+
// CHECK-NOT: stablehlo.slice
446+
// CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<{{\[\[\[}}[1, 2]]]]> : tensor<1x1x1x2xi64>
447+
// CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<{{\[\[\[}}[7, 8]]]]> : tensor<1x1x1x2xi64>
448+
// CHECK: [[RESULT3:%.*]] = stablehlo.constant dense<{{\[\[\[}}[11, 12]]]]> : tensor<1x1x1x2xi64>
449+
// CHECK: return [[RESULT1]], [[RESULT2]], [[RESULT3]]
450+
%0 = stablehlo.constant dense<[[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]> : tensor<1x3x2x2xi64>
451+
452+
%1 = "stablehlo.slice"(%0) {
453+
start_indices = array<i64: 0, 0, 0, 0>,
454+
limit_indices = array<i64: 1, 1, 1, 2>,
455+
strides = array<i64: 1, 1, 1, 1>
456+
} : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>
457+
458+
%2 = "stablehlo.slice"(%0) {
459+
start_indices = array<i64: 0, 1, 1, 0>,
460+
limit_indices = array<i64: 1, 2, 2, 2>,
461+
strides = array<i64: 1, 1, 1, 1>
462+
} : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>
463+
464+
%3 = "stablehlo.slice"(%0) {
465+
start_indices = array<i64: 0, 2, 1, 0>,
466+
limit_indices = array<i64: 1, 3, 2, 2>,
467+
strides = array<i64: 1, 1, 1, 1>
468+
} : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>
469+
470+
func.return %1, %2, %3 : tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>
471+
}
472+
473+
// -----
474+
475+
// CHECK-LABEL: func @eval_slice_non_unit_prefix
476+
func.func @eval_slice_non_unit_prefix() -> tensor<1x2x1xi64> {
477+
// CHECK: stablehlo.constant {{.*}} : tensor<1x2x2xi64>
478+
// CHECK: [[RESULT:%.*]] = stablehlo.slice{{.*}}
479+
// CHECK: return [[RESULT]]
480+
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
481+
%1 = "stablehlo.slice"(%0) {
482+
start_indices = array<i64: 0, 0, 1>,
483+
limit_indices = array<i64: 1, 2, 2>,
484+
strides = array<i64: 1, 1, 1>
485+
} : (tensor<1x2x2xi64>) -> tensor<1x2x1xi64>
486+
func.return %1 : tensor<1x2x1xi64>
487+
}
488+
489+
// -----
490+
427491
// CHECK-LABEL: func @eval_subtract
428492
func.func @eval_subtract() -> tensor<i64> {
429493
// CHECK-NOT: stablehlo.subtract

stablehlo/transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ add_mlir_dialect_library(StablehloPasses
4545
MLIRArithDialect
4646
MLIRAsmParser
4747
MLIRComplexDialect
48+
MLIRDialectUtils
4849
MLIRFuncDialect
4950
MLIRFunctionInterfaces
5051
MLIRIR

stablehlo/transforms/StablehloRefineShapes.cpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License.
2929
#include "llvm/Support/ErrorHandling.h"
3030
#include "llvm/Support/FormatVariadic.h"
3131
#include "mlir/Dialect/Func/IR/FuncOps.h"
32+
#include "mlir/Dialect/Utils/IndexingUtils.h"
3233
#include "mlir/IR/BuiltinAttributes.h"
3334
#include "mlir/IR/BuiltinOps.h"
3435
#include "mlir/IR/BuiltinTypes.h"
@@ -577,20 +578,42 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
577578
LogicalResult matchAndRewrite(SliceOp op,
578579
PatternRewriter& rewriter) const override {
579580
auto resultType = op.getType();
580-
if (!resultType.hasRank() || resultType.getRank() != 1)
581-
return rewriter.notifyMatchFailure(op, "expected 1-dimensional type");
582-
583-
SmallVector<APSInt> operand;
584-
if (failed(hlo::matchInts(op.getOperand(), operand)))
581+
if (resultType.getRank() < 1)
582+
return rewriter.notifyMatchFailure(
583+
op, "expected non-0 ranked tensor result type");
584+
585+
auto operand = op.getOperand().cast<TypedValue<RankedTensorType>>();
586+
RankedTensorType operandType = operand.getType();
587+
if (!operandType.hasStaticShape())
588+
return rewriter.notifyMatchFailure(
589+
op, "expected operand with static ranked tensor type");
590+
591+
// A ranked tensor type with unit dimension prefix of R-1 size is physically
592+
// compatible with 1-dimensional type.
593+
if (!llvm::all_of(resultType.getShape().drop_back(),
594+
[](int64_t s) { return s == 1; }))
595+
return rewriter.notifyMatchFailure(
596+
op, "expected 1-dimensional compatible result type");
597+
598+
SmallVector<APSInt> operandData;
599+
if (failed(hlo::matchInts(operand, operandData)))
585600
return rewriter.notifyMatchFailure(op, "expected constant operand");
586601

587-
int64_t start = op.getStartIndices()[0];
588-
int64_t limit = op.getLimitIndices()[0];
589-
int64_t stride = op.getStrides()[0];
602+
const auto dimOffsets = computeSuffixProduct(operandType.getShape());
603+
auto startIndices = op.getStartIndices();
604+
auto limitIndices = op.getLimitIndices();
605+
auto strides = op.getStrides();
606+
607+
int64_t start = 0;
608+
for (size_t i = 0; i < startIndices.size(); ++i)
609+
start += startIndices[i] * dimOffsets[i];
610+
611+
auto slicedDim = operandType.getRank() - 1;
612+
int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim];
613+
int64_t stride = strides[slicedDim];
590614
SmallVector<APSInt> result;
591-
for (auto i = start; i < limit; i += stride) {
592-
result.push_back(operand[i]);
593-
}
615+
for (auto i = start; i < limit; i += stride)
616+
result.push_back(operandData[i]);
594617

595618
rewriter.replaceOpWithNewOp<ConstantOp>(op,
596619
getTensorAttr(resultType, result));

0 commit comments

Comments
 (0)