File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -209,12 +209,15 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
209209 mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
210210 auto sizePerThread = distributedLayout.getSizePerThread ();
211211 auto threadsPerWarp = distributedLayout.getThreadsPerWarp ();
212- // ThreadsPerWarp does not align with this function for slice layout
212+ auto warpsPerCTA = distributedLayout.getWarpsPerCTA ();
213+ // ThreadsPerWarp and warpsPerCTA does not align with this function for
214+ // slice layout
213215 if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
214216 threadsPerWarp = getThreadsPerWarp (sliceLayout.getParent ());
215217 threadsPerWarp.erase (threadsPerWarp.begin () + sliceLayout.getDim ());
218+ warpsPerCTA = getWarpsPerCTA (sliceLayout.getParent ());
219+ warpsPerCTA.erase (warpsPerCTA.begin () + sliceLayout.getDim ());
216220 }
217- auto warpsPerCTA = distributedLayout.getWarpsPerCTA ();
218221 assert (sizePerThread.size () == threadsPerWarp.size () &&
219222 sizePerThread.size () == warpsPerCTA.size ());
220223 SmallVector<unsigned > shape;
You can’t perform that action at this time.
0 commit comments