@@ -1167,23 +1167,19 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
11671167 return rewriter.notifyMatchFailure (
11681168 insertOp, " distributed dimension must be fully inserted" );
11691169 SmallVector<int64_t > newSourceDistShape (
1170- insertOp.getSourceVectorType ().getShape ()),
1171- newDestDistShape (insertOp.getDestVectorType ().getShape ());
1170+ insertOp.getSourceVectorType ().getShape ());
11721171 newSourceDistShape[sourceDistributedDim] =
11731172 distributedType.getDimSize (destDistributedDim);
1174- newDestDistShape[destDistributedDim] =
1175- distributedType.getDimSize (destDistributedDim);
11761173 auto newSourceTy =
11771174 VectorType::get (newSourceDistShape, distributedType.getElementType ());
1178- auto newDestTy =
1179- VectorType::get (newDestDistShape, distributedType.getElementType ());
1175+ VectorType newDestTy = distributedType;
11801176 SmallVector<size_t > newRetIndices;
11811177 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
11821178 rewriter, warpOp, {insertOp.getValueToStore (), insertOp.getDest ()},
11831179 {newSourceTy, newDestTy}, newRetIndices);
11841180 rewriter.setInsertionPointAfter (newWarpOp);
1185- auto distributedSource = newWarpOp->getResult (newRetIndices[0 ]);
1186- auto distributedDest = newWarpOp->getResult (newRetIndices[1 ]);
1181+ Value distributedSource = newWarpOp->getResult (newRetIndices[0 ]);
1182+ Value distributedDest = newWarpOp->getResult (newRetIndices[1 ]);
11871183 // Create a new insert strided slice op that inserts distributed source into
11881184 // distributed dest.
11891185 Value newInsert = rewriter.create <vector::InsertStridedSliceOp>(
0 commit comments