@@ -55,12 +55,16 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
5555 return map;
5656}
5757
58- static int getDistributedDim (VectorType origType, VectorType distributedType) {
59- assert (origType.getRank () == distributedType.getRank () &&
58+ // / Given a sequential and distributed vector type, returns the distributed
59+ // / dimension. This function expects that only a single dimension is
60+ // / distributed.
61+ static int getDistributedDim (VectorType sequentialType,
62+ VectorType distributedType) {
63+ assert (sequentialType.getRank () == distributedType.getRank () &&
6064 " sequential and distributed vector types must have the same rank" );
6165 int64_t distributedDim = -1 ;
62- for (int64_t i = 0 ; i < origType .getRank (); ++i) {
63- if (distributedType.getDimSize (i) != origType .getDimSize (i)) {
66+ for (int64_t i = 0 ; i < sequentialType .getRank (); ++i) {
67+ if (distributedType.getDimSize (i) != sequentialType .getDimSize (i)) {
6468 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
6569 // support distributing multiple dimensions in the future.
6670 assert (distributedDim == -1 && " found multiple distributed dims" );
@@ -1234,7 +1238,6 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
12341238 auto yieldedType = cast<VectorType>(operand->get ().getType ());
12351239 int64_t distributedDim = getDistributedDim (yieldedType, distributedType);
12361240 assert (distributedDim != -1 && " could not find distributed dimension" );
1237- (void )distributedDim;
12381241
12391242 // Distributed dimension must be fully extracted.
12401243 // TODO: Partial extraction from distributed dimension require cross lane
0 commit comments