Skip to content

Commit d093220

Browse files
chengjunlu authored and whitneywhtsang committed
Fix bug of getShapePerCTATile for slice layout
1 parent 6e2da6a commit d093220

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -209,12 +209,15 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
209209
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
210210
auto sizePerThread = distributedLayout.getSizePerThread();
211211
auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
212-
// ThreadsPerWarp does not align with this function for slice layout
212+
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
213+
// ThreadsPerWarp and warpsPerCTA do not align with this function for
214+
// slice layout
213215
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
214216
threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
215217
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
218+
warpsPerCTA = getWarpsPerCTA(sliceLayout.getParent());
219+
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
216220
}
217-
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
218221
assert(sizePerThread.size() == threadsPerWarp.size() &&
219222
sizePerThread.size() == warpsPerCTA.size());
220223
SmallVector<unsigned> shape;

0 commit comments

Comments
 (0)