Skip to content

Commit 535092b

Browse files
committed
address comments
1 parent 27438d3 commit 535092b

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,23 +1167,19 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
11671167
return rewriter.notifyMatchFailure(
11681168
insertOp, "distributed dimension must be fully inserted");
11691169
SmallVector<int64_t> newSourceDistShape(
1170-
insertOp.getSourceVectorType().getShape()),
1171-
newDestDistShape(insertOp.getDestVectorType().getShape());
1170+
insertOp.getSourceVectorType().getShape());
11721171
newSourceDistShape[sourceDistributedDim] =
11731172
distributedType.getDimSize(destDistributedDim);
1174-
newDestDistShape[destDistributedDim] =
1175-
distributedType.getDimSize(destDistributedDim);
11761173
auto newSourceTy =
11771174
VectorType::get(newSourceDistShape, distributedType.getElementType());
1178-
auto newDestTy =
1179-
VectorType::get(newDestDistShape, distributedType.getElementType());
1175+
VectorType newDestTy = distributedType;
11801176
SmallVector<size_t> newRetIndices;
11811177
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
11821178
rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
11831179
{newSourceTy, newDestTy}, newRetIndices);
11841180
rewriter.setInsertionPointAfter(newWarpOp);
1185-
auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
1186-
auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
1181+
Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1182+
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
11871183
// Create a new insert strided slice op that inserts distributed source into
11881184
// distributed dest.
11891185
Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(

0 commit comments

Comments
 (0)