@@ -1507,12 +1507,13 @@ struct VectorExtractStridedSliceDistribution
15071507 auto extractResultType = cast<VectorType>(operand->get ().getType ());
15081508 auto distributedDims =
15091509 getDistributedDims (extractResultType, distributedType);
1510- // Source distributed type must be adjusted for the distributed case.
1511- VectorType sourceDistType = extractOp.getSourceVectorType ();
1512- // Distributed sizes and offsets must be adjusted for distributed case.
1513- SmallVector<Attribute> distributedSizes = llvm::map_to_vector (
1510+ // Collect updated source type, sizes and offsets. They may be adjusted
1511+ // later if the data is distributed to lanes (as opposed to being owned by
1512+ // all lanes uniformly).
1513+ VectorType updatedSourceType = extractOp.getSourceVectorType ();
1514+ SmallVector<Attribute> updatedSizes = llvm::map_to_vector (
15141515 extractOp.getSizes (), [](Attribute attr) { return attr; });
1515- SmallVector<Attribute> distributedOffsets = llvm::map_to_vector (
1516+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector (
15161517 extractOp.getOffsets (), [](Attribute attr) { return attr; });
15171518 // If the result is distributed, it must be distributed in exactly one
15181519 // dimension. In this case, we adjust the sourceDistType, distributedSizes
@@ -1554,31 +1555,30 @@ struct VectorExtractStridedSliceDistribution
15541555 return rewriter.notifyMatchFailure (
15551556 warpOp, " Offset along distributed dimension "
15561557 " is not a multiple of subgroup size." );
1557- // Do the distribution by yielding the source of the extract op from
1558- // the warp op and creating a new extract op outside the warp op.
1559- sourceDistType = getDistVecTypeBasedOnLaneLayout (
1560- sourceLayout, extractOp.getSourceVectorType ())
1561- .value ();
1558+ updatedSourceType = getDistVecTypeBasedOnLaneLayout (
1559+ sourceLayout, extractOp.getSourceVectorType ())
1560+ .value ();
15621561 // Update the distributed sizes to match the distributed type.
1563- distributedSizes [distributedDim] = rewriter.getI64IntegerAttr (
1562+ updatedSizes [distributedDim] = rewriter.getI64IntegerAttr (
15641563 distributedType.getDimSize (distributedDim));
15651564 // Update the distributed offsets to match round robin distribution (i.e.
15661565 // each lane owns data at `subgroupSize` stride given unit lane data).
1567- distributedOffsets [distributedDim] =
1566+ updatedOffsets [distributedDim] =
15681567 rewriter.getI64IntegerAttr (distrDimOffset / subgroupSize);
15691568 }
1570- // Create a new warp op that yields the source of the extract op.
1569+ // Do the distribution by yielding the source of the extract op from
1570+ // the warp op and creating a new extract op outside the warp op.
15711571 SmallVector<size_t > newRetIndices;
15721572 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1573- rewriter, warpOp, {extractOp.getSource ()}, {sourceDistType },
1573+ rewriter, warpOp, {extractOp.getSource ()}, {updatedSourceType },
15741574 newRetIndices);
15751575 rewriter.setInsertionPointAfter (newWarpOp);
15761576 Value source = newWarpOp.getResult (newRetIndices[0 ]);
15771577 // Create a new extract op outside the warp op.
15781578 Value newExtractOp = vector::ExtractStridedSliceOp::create (
15791579 rewriter, extractOp.getLoc (), distributedType, source,
1580- ArrayAttr::get (rewriter.getContext (), distributedOffsets ),
1581- ArrayAttr::get (rewriter.getContext (), distributedSizes ),
1580+ ArrayAttr::get (rewriter.getContext (), updatedOffsets ),
1581+ ArrayAttr::get (rewriter.getContext (), updatedSizes ),
15821582 extractOp.getStrides ());
15831583 rewriter.replaceAllUsesWith (newWarpOp.getResult (operandIdx), newExtractOp);
15841584 return success ();
@@ -1608,7 +1608,7 @@ struct VectorInsertStridedSliceDistribution
16081608 auto insertResultType = cast<VectorType>(operand->get ().getType ());
16091609 auto destDistributedDims =
16101610 getDistributedDims (insertResultType, distributedType);
1611- // Collect updated offsets, source type and dest type. They may be updated
1611+ // Collect updated offsets, source type and dest type. They may be adjusted
16121612 // later if the data is distributed to lanes (as opposed to being owned by
16131613 // all lanes uniformly).
16141614 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector (
0 commit comments