@@ -2291,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames(
22912291// / An extract_slice result type can be inferred, when it is not
22922292// / rank-reduced, from the source type and the static representation of
22932293// / offsets, sizes and strides. Special sentinels encode the dynamic case.
2294- RankedTensorType ExtractSliceOp::inferResultType (
2295- RankedTensorType sourceTensorType, ArrayRef< int64_t > staticOffsets ,
2296- ArrayRef< int64_t > staticSizes, ArrayRef<int64_t > staticStrides ) {
2294+ RankedTensorType
2295+ ExtractSliceOp::inferResultType ( RankedTensorType sourceTensorType,
2296+ ArrayRef<int64_t > staticSizes ) {
22972297 // An extract_slice op may specify only a leading subset of offset/sizes/
22982298 // strides in which case we complete with offset=0, sizes from memref type
22992299 // and strides=1.
@@ -2305,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
23052305}
23062306
23072307// TODO: This uses neither offsets nor strides!
2308- RankedTensorType ExtractSliceOp::inferResultType (
2309- RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets ,
2310- ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides ) {
2308+ RankedTensorType
2309+ ExtractSliceOp::inferResultType ( RankedTensorType sourceTensorType,
2310+ ArrayRef<OpFoldResult> sizes ) {
23112311 SmallVector<int64_t > staticSizes;
23122312 std::tie (staticSizes, std::ignore) = decomposeMixedValues (sizes);
2313+
23132314 assert (static_cast <int64_t >(staticSizes.size ()) ==
23142315 sourceTensorType.getRank () &&
23152316 " unexpected staticSizes not equal to rank of source" );
@@ -2327,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
23272328// / To disambiguate, this function always drops the first 1 sizes occurrences.
23282329RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType (
23292330 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2330- ArrayRef<int64_t > offsets, ArrayRef<int64_t > sizes,
2331- ArrayRef<int64_t > strides) {
2331+ ArrayRef<int64_t > sizes) {
23322332 // Type inferred in the absence of rank-reducing behavior.
23332333 auto inferredType = llvm::cast<RankedTensorType>(
2334- inferResultType (sourceRankedTensorType, offsets, sizes, strides ));
2334+ inferResultType (sourceRankedTensorType, sizes));
23352335 int rankDiff = inferredType.getRank () - desiredResultRank;
23362336 if (rankDiff > 0 ) {
23372337 auto shape = inferredType.getShape ();
@@ -2350,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
23502350
23512351RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType (
23522352 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2353- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2354- ArrayRef<OpFoldResult> strides) {
2355- SmallVector<int64_t > staticOffsets, staticSizes, staticStrides;
2356- SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2357- dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
2353+ ArrayRef<OpFoldResult> sizes) {
2354+ SmallVector<int64_t > staticSizes;
2355+ SmallVector<Value> dynamicSizes;
23582356 dispatchIndexOpFoldResults (sizes, dynamicSizes, staticSizes);
2359- dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
23602357 return ExtractSliceOp::inferCanonicalRankReducedResultType (
2361- desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2362- staticStrides);
2358+ desiredResultRank, sourceRankedTensorType, staticSizes);
23632359}
23642360
23652361// / Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2378,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
23782374 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType ());
23792375 // Structuring implementation this way avoids duplication between builders.
23802376 if (!resultType) {
2381- resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType (
2382- sourceRankedTensorType, staticOffsets, staticSizes, staticStrides ));
2377+ resultType = llvm::cast<RankedTensorType>(
2378+ ExtractSliceOp::inferResultType ( sourceRankedTensorType, staticSizes));
23832379 }
23842380 result.addAttributes (attrs);
23852381 build (b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2454,8 +2450,8 @@ LogicalResult ExtractSliceOp::verify() {
24542450 RankedTensorType sourceType = getSourceType ();
24552451
24562452 // Verify result type against inferred type.
2457- RankedTensorType expectedType = ExtractSliceOp::inferResultType (
2458- sourceType, getMixedOffsets (), getMixedSizes (), getMixedStrides ());
2453+ RankedTensorType expectedType =
2454+ ExtractSliceOp::inferResultType ( sourceType, getMixedSizes ());
24592455 SliceVerificationResult result = isRankReducedType (expectedType, getType ());
24602456 if (result != SliceVerificationResult::Success)
24612457 return produceSliceErrorMsg (result, *this , expectedType);
@@ -2695,8 +2691,7 @@ struct SliceReturnTypeCanonicalizer {
26952691 ArrayRef<OpFoldResult> mixedSizes,
26962692 ArrayRef<OpFoldResult> mixedStrides) {
26972693 return ExtractSliceOp::inferCanonicalRankReducedResultType (
2698- op.getType ().getRank (), op.getSourceType (), mixedOffsets, mixedSizes,
2699- mixedStrides);
2694+ op.getType ().getRank (), op.getSourceType (), mixedSizes);
27002695 }
27012696};
27022697
@@ -2837,8 +2832,8 @@ static SliceVerificationResult verifyInsertSliceOp(
28372832 ArrayRef<int64_t > staticStrides, RankedTensorType *expectedType = nullptr ) {
28382833 // insert_slice is the inverse of extract_slice, use the same type
28392834 // inference.
2840- RankedTensorType expected = ExtractSliceOp::inferResultType (
2841- dstType, staticOffsets, staticSizes, staticStrides );
2835+ RankedTensorType expected =
2836+ ExtractSliceOp::inferResultType ( dstType, staticSizes);
28422837 if (expectedType)
28432838 *expectedType = expected;
28442839 return isRankReducedType (expected, srcType);
@@ -2966,7 +2961,7 @@ class InsertSliceOpConstantArgumentFolder final
29662961 // Create the new op in canonical form.
29672962 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType (
29682963 insertSliceOp.getSourceType ().getRank (), insertSliceOp.getDestType (),
2969- mixedOffsets, mixedSizes, mixedStrides );
2964+ mixedSizes);
29702965 Value toInsert = insertSliceOp.getSource ();
29712966 if (sourceType != insertSliceOp.getSourceType ()) {
29722967 OpBuilder::InsertionGuard g (rewriter);
0 commit comments