Skip to content

Commit 7e21f87

Browse files
committed
Fix bug of getShapePerCTATile for slice layout
1 parent 3cf40df commit 7e21f87

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,14 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
209209
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
210210
auto sizePerThread = distributedLayout.getSizePerThread();
211211
auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
212+
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
212213
// ThreadsPerWarp does not align with this function for slice layout
213214
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
214215
threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
215216
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
217+
warpsPerCTA = getWarpsPerCTA(sliceLayout.getParent());
218+
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
216219
}
217-
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
218220
assert(sizePerThread.size() == threadsPerWarp.size() &&
219221
sizePerThread.size() == warpsPerCTA.size());
220222
SmallVector<unsigned> shape;

0 commit comments

Comments
 (0)