Commit 64b232e

Dewei-Wang-sh authored and whitneywhtsang committed
Fix to get correct warps (#2539)
1 parent 80700bf commit 64b232e

1 file changed: +2 -0 lines changed

third_party/intel/lib/TritonIntelGPUTransforms/DistributeToWarps.cpp

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,8 @@ SmallVector<Value> distributeOffset(const SmallVector<Value> &oldOffsets,
                                     RankedTensorType tensorType, Value warpId,
                                     OpBuilder b, Location loc) {
   Attribute layout = tensorType.getEncoding();
+  if (auto dotEncoding = dyn_cast<ttg::DotOperandEncodingAttr>(layout))
+    layout = dotEncoding.getParent();
   const SmallVector<unsigned> &warpsPerCTA = ttg::getWarpsPerCTA(layout);
   size_t dims = warpsPerCTA.size();
   assert(dims <= 2 && "no more than 2D shape");
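
The two added lines unwrap a dot-operand encoding to its parent layout before the warp distribution is read. Below is a minimal standalone sketch of that pattern, written without MLIR or Triton so it can be compiled on its own (C++17). The names `BlockedLayout`, `DotOperandLayout`, `Layout`, and `warpsPerCTAOf` are hypothetical stand-ins introduced only for illustration; they are not the real `ttg` types, and the sketch makes no claim about what `ttg::getWarpsPerCTA` returns for a dot-operand encoding.

```cpp
#include <cassert>
#include <variant>
#include <vector>

// Hypothetical stand-in for a layout that records how warps tile the CTA.
struct BlockedLayout {
  std::vector<unsigned> warpsPerCTA;
};

// Hypothetical stand-in for a dot-operand encoding that wraps a parent layout.
struct DotOperandLayout {
  unsigned opIdx;
  BlockedLayout parent;
};

// A layout is either a plain layout or a dot-operand wrapper around one.
using Layout = std::variant<BlockedLayout, DotOperandLayout>;

// If the layout is a dot-operand wrapper, read warpsPerCTA from its parent;
// otherwise read it from the layout itself. This mirrors the shape of the
// dyn_cast + getParent() check added in the diff.
std::vector<unsigned> warpsPerCTAOf(const Layout &layout) {
  if (const auto *dot = std::get_if<DotOperandLayout>(&layout))
    return dot->parent.warpsPerCTA;
  return std::get<BlockedLayout>(layout).warpsPerCTA;
}
```

In this sketch, calling `warpsPerCTAOf` on a `DotOperandLayout` whose parent has `warpsPerCTA = {4, 2}` yields the same vector as calling it on that parent directly, which is the behavior the commit title ("Fix to get correct warps") appears to be after.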
