File tree Expand file tree Collapse file tree 1 file changed +3
-1
lines changed Expand file tree Collapse file tree 1 file changed +3
-1
lines changed Original file line number Diff line number Diff line change @@ -209,12 +209,14 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
209209 mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
210210 auto sizePerThread = distributedLayout.getSizePerThread ();
211211 auto threadsPerWarp = distributedLayout.getThreadsPerWarp ();
212+ auto warpsPerCTA = distributedLayout.getWarpsPerCTA ();
212213 // ThreadsPerWarp does not align with this function for slice layout
213214 if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
214215 threadsPerWarp = getThreadsPerWarp (sliceLayout.getParent ());
215216 threadsPerWarp.erase (threadsPerWarp.begin () + sliceLayout.getDim ());
217+ warpsPerCTA = getWarpsPerCTA (sliceLayout.getParent ());
218+ warpsPerCTA.erase (warpsPerCTA.begin () + sliceLayout.getDim ());
216219 }
217- auto warpsPerCTA = distributedLayout.getWarpsPerCTA ();
218220 assert (sizePerThread.size () == threadsPerWarp.size () &&
219221 sizePerThread.size () == warpsPerCTA.size ());
220222 SmallVector<unsigned > shape;
You can’t perform that action at this time.
0 commit comments