Skip to content

Commit b96538c

Browse files
committed
add comment
1 parent 4dfbe7c commit b96538c

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,12 +1507,13 @@ struct VectorExtractStridedSliceDistribution
15071507
auto extractResultType = cast<VectorType>(operand->get().getType());
15081508
auto distributedDims =
15091509
getDistributedDims(extractResultType, distributedType);
1510-
// Source distributed type must be adjusted for the distributed case.
1511-
VectorType sourceDistType = extractOp.getSourceVectorType();
1512-
// Distributed sizes and offsets must be adjusted for distributed case.
1513-
SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1510+
// Collect updated source type, sizes and offsets. They may be adjusted
1511+
// later if the data is distributed to lanes (as opposed to being owned by
1512+
// all lanes uniformly).
1513+
VectorType updatedSourceType = extractOp.getSourceVectorType();
1514+
SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
15141515
extractOp.getSizes(), [](Attribute attr) { return attr; });
1515-
SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
1516+
SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
15161517
extractOp.getOffsets(), [](Attribute attr) { return attr; });
15171518
// If the result is distributed, it must be distributed in exactly one
15181519
// dimension. In this case, we adjust the sourceDistType, distributedSizes
@@ -1554,31 +1555,30 @@ struct VectorExtractStridedSliceDistribution
15541555
return rewriter.notifyMatchFailure(
15551556
warpOp, "Offset along distributed dimension "
15561557
"is not a multiple of subgroup size.");
1557-
// Do the distribution by yielding the source of the extract op from
1558-
// the warp op and creating a new extract op outside the warp op.
1559-
sourceDistType = getDistVecTypeBasedOnLaneLayout(
1560-
sourceLayout, extractOp.getSourceVectorType())
1561-
.value();
1558+
updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1559+
sourceLayout, extractOp.getSourceVectorType())
1560+
.value();
15621561
// Update the distributed sizes to match the distributed type.
1563-
distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1562+
updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
15641563
distributedType.getDimSize(distributedDim));
15651564
// Update the distributed offsets to match round robin distribution (i.e.
15661565
// each lane owns data at `subgroupSize` stride given unit lane data).
1567-
distributedOffsets[distributedDim] =
1566+
updatedOffsets[distributedDim] =
15681567
rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
15691568
}
1570-
// Create a new warp op that yields the source of the extract op.
1569+
// Do the distribution by yielding the source of the extract op from
1570+
// the warp op and creating a new extract op outside the warp op.
15711571
SmallVector<size_t> newRetIndices;
15721572
auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1573-
rewriter, warpOp, {extractOp.getSource()}, {sourceDistType},
1573+
rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
15741574
newRetIndices);
15751575
rewriter.setInsertionPointAfter(newWarpOp);
15761576
Value source = newWarpOp.getResult(newRetIndices[0]);
15771577
// Create a new extract op outside the warp op.
15781578
Value newExtractOp = vector::ExtractStridedSliceOp::create(
15791579
rewriter, extractOp.getLoc(), distributedType, source,
1580-
ArrayAttr::get(rewriter.getContext(), distributedOffsets),
1581-
ArrayAttr::get(rewriter.getContext(), distributedSizes),
1580+
ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1581+
ArrayAttr::get(rewriter.getContext(), updatedSizes),
15821582
extractOp.getStrides());
15831583
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
15841584
return success();
@@ -1608,7 +1608,7 @@ struct VectorInsertStridedSliceDistribution
16081608
auto insertResultType = cast<VectorType>(operand->get().getType());
16091609
auto destDistributedDims =
16101610
getDistributedDims(insertResultType, distributedType);
1611-
// Collect updated offsets, source type and dest type. They may be updated
1611+
// Collect updated offsets, source type and dest type. They may be adjusted
16121612
// later if the data is distributed to lanes (as opposed to being owned by
16131613
// all lanes uniformly).
16141614
SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(

0 commit comments

Comments
 (0)