@@ -177,13 +177,13 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
177177
178178// / Given a sequential and distributed vector type, return the list of
179179// / dimensions that are distributed.
180- static SmallVector<int64_t > getDistributedDims (VectorType sequentialType ,
180+ static SmallVector<int64_t > getDistributedDims (VectorType originalType ,
181181 VectorType distributedType) {
182- assert (sequentialType .getRank () == distributedType.getRank () &&
182+ assert (originalType .getRank () == distributedType.getRank () &&
183183 " sequential and distributed vector types must have the same rank" );
184184 SmallVector<int64_t > distributedDims;
185- for (int64_t i = 0 ; i < sequentialType .getRank (); ++i) {
186- if (distributedType.getDimSize (i) != sequentialType .getDimSize (i)) {
185+ for (int64_t i = 0 ; i < originalType .getRank (); ++i) {
186+ if (distributedType.getDimSize (i) != originalType .getDimSize (i)) {
187187 distributedDims.push_back (i);
188188 }
189189 }
@@ -1486,9 +1486,9 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
14861486};
14871487
14881488// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
1489- // enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
1490- // advanced cases where the distributed is partially extracted and currently not
1491- // supported by the generic vector distribution patterns.
1489+ // enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
1490+ // advanced cases where the distributed dimension is partially extracted and
1491+ // currently not supported by the generic vector distribution patterns.
14921492struct VectorExtractStridedSliceDistribution
14931493 : public gpu::WarpDistributionPattern {
14941494 using gpu::WarpDistributionPattern::WarpDistributionPattern;
@@ -1503,7 +1503,7 @@ struct VectorExtractStridedSliceDistribution
15031503 unsigned operandIdx = operand->getOperandNumber ();
15041504 auto distributedType =
15051505 cast<VectorType>(warpOp.getResult (operandIdx).getType ());
1506- // Find the distributed dimension. There should be exactly one .
1506+ // Find the distributed dimensions .
15071507 auto extractResultType = cast<VectorType>(operand->get ().getType ());
15081508 auto distributedDims =
15091509 getDistributedDims (extractResultType, distributedType);
@@ -1586,7 +1586,7 @@ struct VectorExtractStridedSliceDistribution
15861586};
15871587
15881588// / Distribute a `vector.insert_strided_slice` op feeding into yield op of an
1589- // / enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
1589+ // / enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
15901590// / advanced cases where the distributed dimension is partially inserted and
15911591// / currently not supported by the generic vector distribution patterns.
15921592struct VectorInsertStridedSliceDistribution
@@ -1603,8 +1603,7 @@ struct VectorInsertStridedSliceDistribution
16031603 operand->get ().getDefiningOp <vector::InsertStridedSliceOp>();
16041604 auto distributedType =
16051605 cast<VectorType>(warpOp.getResult (operandNumber).getType ());
1606- // Find the distributed dimension of the dest vector. There should be
1607- // exactly one.
1606+ // Find the distributed dimensions of the dest vector.
16081607 auto insertResultType = cast<VectorType>(operand->get ().getType ());
16091608 auto destDistributedDims =
16101609 getDistributedDims (insertResultType, distributedType);
0 commit comments