@@ -521,6 +521,41 @@ struct EvalSignOpPattern : public OpRewritePattern<SignOp> {
521521 }
522522};
523523
524+ template <typename RangeType>
525+ DenseElementsAttr sliceType (SliceOp& op, const RangeType& data) {
526+ using ElementType = std::decay_t <decltype (*std::begin (data))>;
527+
528+ RankedTensorType operandType = op.getOperand ().getType ();
529+ RankedTensorType resultType = op.getResult ().getType ();
530+
531+ const auto dimOffsets = computeStrides (operandType.getShape ());
532+ auto startIndices = op.getStartIndices ();
533+ auto limitIndices = op.getLimitIndices ();
534+ auto strides = op.getStrides ();
535+
536+ const SmallVector<int64_t > startIndex (startIndices);
537+ const SmallVector<int64_t > endIndex (limitIndices);
538+
539+ SmallVector<ElementType> result;
540+ result.reserve (resultType.getNumElements ());
541+
542+ SmallVector<int64_t > srcIndex (startIndex);
543+ for (int64_t i = 0 ; i < resultType.getNumElements (); ++i) {
544+ auto srcLinearIndex = linearize (srcIndex, dimOffsets);
545+ result.push_back (data[srcLinearIndex]);
546+ for (int64_t dim = srcIndex.size () - 1 ; dim >= 0 ; --dim) {
547+ srcIndex[dim] += strides[dim];
548+ if (srcIndex[dim] >= endIndex[dim])
549+ srcIndex[dim] = startIndex[dim];
550+ else
551+ break ;
552+ }
553+ }
554+
555+ return DenseElementsAttr::get (op.getResult ().getType (),
556+ ArrayRef<ElementType>(result));
557+ }
558+
524559struct EvalSliceOpPattern : public OpRewritePattern <SliceOp> {
525560 using OpRewritePattern::OpRewritePattern;
526561 LogicalResult matchAndRewrite (SliceOp op,
@@ -529,45 +564,27 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
529564 if (failed (validateResultTypeForEval (rewriter, op, resultType)))
530565 return failure ();
531566
532- if (resultType.getRank () < 1 )
533- return rewriter.notifyMatchFailure (
534- op, " expected non-0 ranked tensor result type" );
535-
536- auto operand = cast<TypedValue<RankedTensorType>>(op.getOperand ());
567+ auto operand = op.getOperand ();
537568 RankedTensorType operandType = operand.getType ();
538569 if (!operandType.hasStaticShape ())
539570 return rewriter.notifyMatchFailure (
540571 op, " expected operand with static ranked tensor type" );
541572
542- // A ranked tensor type with unit dimension prefix of R-1 size is physically
543- // compatible with 1-dimensional type.
544- if (!llvm::all_of (resultType.getShape ().drop_back (),
545- [](int64_t s) { return s == 1 ; }))
573+ ElementsAttr els;
574+ if (!matchPattern (operand, m_Constant (&els)))
546575 return rewriter.notifyMatchFailure (
547- op, " expected 1-dimensional compatible result type" );
548-
549- SmallVector<APSInt> operandData;
550- if (failed (hlo::matchInts (operand, operandData)))
551- return rewriter.notifyMatchFailure (op, " expected constant operand" );
552-
553- const auto dimOffsets = computeSuffixProduct (operandType.getShape ());
554- auto startIndices = op.getStartIndices ();
555- auto limitIndices = op.getLimitIndices ();
556- auto strides = op.getStrides ();
557-
558- int64_t start = 0 ;
559- for (size_t i = 0 ; i < startIndices.size (); ++i)
560- start += startIndices[i] * dimOffsets[i];
576+ op, " expected constant integer or float operand" );
561577
562- auto slicedDim = operandType.getRank () - 1 ;
563- int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim];
564- int64_t stride = strides[slicedDim];
565- SmallVector<APSInt> result;
566- for (auto i = start; i < limit; i += stride)
567- result.push_back (operandData[i]);
578+ DenseElementsAttr resAttr;
579+ if (auto data = els.tryGetValues <APInt>())
580+ resAttr = sliceType (op, *data);
581+ else if (auto data = els.tryGetValues <APFloat>())
582+ resAttr = sliceType (op, *data);
583+ else
584+ return rewriter.notifyMatchFailure (op.getLoc (),
585+ " unsupported element type" );
568586
569- rewriter.replaceOpWithNewOp <ConstantOp>(op,
570- getTensorAttr (resultType, result));
587+ rewriter.replaceOpWithNewOp <ConstantOp>(op, resAttr);
571588 return success ();
572589 }
573590};
0 commit comments