Skip to content

Commit 92fe035

Browse files
committed
address comments
1 parent 36c2382 commit 92fe035

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,13 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
177177

178178
/// Given an original (sequential) and a distributed vector type, return the list of
179179
/// dimensions that are distributed.
180-
static SmallVector<int64_t> getDistributedDims(VectorType sequentialType,
180+
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
181181
VectorType distributedType) {
182-
assert(sequentialType.getRank() == distributedType.getRank() &&
182+
assert(originalType.getRank() == distributedType.getRank() &&
183183
"sequential and distributed vector types must have the same rank");
184184
SmallVector<int64_t> distributedDims;
185-
for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
186-
if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
185+
for (int64_t i = 0; i < originalType.getRank(); ++i) {
186+
if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
187187
distributedDims.push_back(i);
188188
}
189189
}
@@ -1486,9 +1486,9 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
14861486
};
14871487

14881488
// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1489-
// enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
1490-
// advanced cases where the distributed is partially extracted and currently not
1491-
// supported by the generic vector distribution patterns.
1489+
// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1490+
// advanced cases where the distributed dimension is partially extracted and
1491+
// currently not supported by the generic vector distribution patterns.
14921492
struct VectorExtractStridedSliceDistribution
14931493
: public gpu::WarpDistributionPattern {
14941494
using gpu::WarpDistributionPattern::WarpDistributionPattern;
@@ -1503,7 +1503,7 @@ struct VectorExtractStridedSliceDistribution
15031503
unsigned operandIdx = operand->getOperandNumber();
15041504
auto distributedType =
15051505
cast<VectorType>(warpOp.getResult(operandIdx).getType());
1506-
// Find the distributed dimension. There should be exactly one.
1506+
// Find the distributed dimensions.
15071507
auto extractResultType = cast<VectorType>(operand->get().getType());
15081508
auto distributedDims =
15091509
getDistributedDims(extractResultType, distributedType);
@@ -1586,7 +1586,7 @@ struct VectorExtractStridedSliceDistribution
15861586
};
15871587

15881588
/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1589-
/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
1589+
/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
15901590
/// advanced cases where the distributed dimension is partially inserted and
15911591
/// currently not supported by the generic vector distribution patterns.
15921592
struct VectorInsertStridedSliceDistribution
@@ -1603,8 +1603,7 @@ struct VectorInsertStridedSliceDistribution
16031603
operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
16041604
auto distributedType =
16051605
cast<VectorType>(warpOp.getResult(operandNumber).getType());
1606-
// Find the distributed dimension of the dest vector. There should be
1607-
// exactly one.
1606+
// Find the distributed dimensions of the dest vector.
16081607
auto insertResultType = cast<VectorType>(operand->get().getType());
16091608
auto destDistributedDims =
16101609
getDistributedDims(insertResultType, distributedType);

0 commit comments

Comments
 (0)