Skip to content

Commit 3ae71df

Browse files
committed
address comments
1 parent 0993ce9 commit 3ae71df

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,7 +1148,7 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
11481148
int64_t destDistributedDim =
11491149
getDistributedDim(yieldedType, distributedType);
11501150
assert(destDistributedDim != -1 && "could not find distributed dimension");
1151-
(void)destDistributedDim;
1151+
11521152
VectorType srcType = insertOp.getSourceVectorType();
11531153
VectorType destType = insertOp.getDestVectorType();
11541154
// Currently we require that both source (kD) and dest (nD) vectors are
@@ -1242,7 +1242,9 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
12421242
// Distributed dimension must be fully extracted.
12431243
// TODO: Partial extraction from distributed dimension require cross lane
12441244
// communication.
1245-
if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
1245+
int64_t extractedDimsRank =
1246+
static_cast<int64_t>(extractOp.getSizes().size());
1247+
if (distributedDim < extractedDimsRank) {
12461248
int64_t distributedDimOffset =
12471249
llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
12481250
.getInt();

0 commit comments

Comments
 (0)