File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed
mlir/lib/Dialect/Vector/Transforms Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -1240,12 +1240,15 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
12401240 int64_t distributedDim = getDistributedDim (yieldedType, distributedType);
12411241 assert (distributedDim != -1 && " could not find distributed dimension" );
12421242
1243- // Distributed dimension must be fully extracted.
1243+ int64_t numOfExtractedDims =
1244+ static_cast <int64_t >(extractOp.getSizes ().size ());
1245+ // If the distributed dim is included in the extracted dims, then we make
1246+ // sure distributed dim is fully extracted. If distributed dim is not
1247+ // included in extracted dims, it is guaranteed to be fully extracted (i.e.
1248+ // distributed dim comes after all the extracted dims)
12441249 // TODO: Partial extraction from distributed dimension require cross lane
12451250 // communication.
1246- int64_t extractedDimsRank =
1247- static_cast <int64_t >(extractOp.getSizes ().size ());
1248- if (distributedDim < extractedDimsRank) {
1251+ if (distributedDim < numOfExtractedDims) {
12491252 int64_t distributedDimOffset =
12501253 llvm::cast<IntegerAttr>(extractOp.getOffsets ()[distributedDim])
12511254 .getInt ();
You can’t perform that action at this time.
0 commit comments