Skip to content

Commit 99cbbe5

Browse files
committed
address comments
1 parent b087820 commit 99cbbe5

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,12 +1240,15 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
12401240
int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
12411241
assert(distributedDim != -1 && "could not find distributed dimension");
12421242

1243-
// Distributed dimension must be fully extracted.
1243+
int64_t numOfExtractedDims =
1244+
static_cast<int64_t>(extractOp.getSizes().size());
1245+
// If the distributed dim is included in the extracted dims, then we make
1246+
// sure distributed dim is fully extracted. If distributed dim is not
1247+
// included in extracted dims, it is guaranteed to be fully extracted (i.e.
1248+
// distributed dim comes after all the extracted dims)
12441249
// TODO: Partial extraction from distributed dimension require cross lane
12451250
// communication.
1246-
int64_t extractedDimsRank =
1247-
static_cast<int64_t>(extractOp.getSizes().size());
1248-
if (distributedDim < extractedDimsRank) {
1251+
if (distributedDim < numOfExtractedDims) {
12491252
int64_t distributedDimOffset =
12501253
llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
12511254
.getInt();

0 commit comments

Comments
 (0)