@@ -223,6 +223,7 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
223223 unsigned executionSize = dpasLayout.getExecutionSize ();
224224 unsigned opsPerChannel = dpasLayout.getOpsPerChannel ();
225225
226+ unsigned rank = shape.size ();
226227 unsigned numRowsPerPackedValue = 0u , numColsPerPackedValue = 0u ;
227228 unsigned numColsPerLaneForPackedValue = 0u , numOpsPerPackedValue = 0u ;
228229 switch (opIdx) {
@@ -232,7 +233,7 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
232233 SmallVector<unsigned > shapeA = dpasLayout.getShapeA ();
233234 // Unlike operand B, the values are packed to i16 when the scalar bit width is <= 16.
234235 numOpsPerPackedValue = opsPerChannel == 4 ? 2 : 1 ;
235- unsigned packedColNum = shapeA[1 ] / numOpsPerPackedValue;
236+ unsigned packedColNum = shapeA[rank - 1 ] / numOpsPerPackedValue;
236237 // Each value name represents multiple rows if warpSize > packedColNum
237238 numRowsPerPackedValue = mlir::ceil (warpSize, packedColNum);
238239 numColsPerPackedValue = std::min (warpSize, packedColNum);
@@ -256,9 +257,9 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
256257 int64_t numRepK = numReps[opIdx ? 1 : 2 ];
257258
258259 ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
259- unsigned repClusterSize = repCluster[opIdx];
260+ unsigned repClusterSize = repCluster[opIdx ? rank - 1 : rank - 2 ];
260261
261- for (unsigned dimOuter = 0 ; dimOuter < numRepOuter; ++dimOuter )
262+ for (unsigned repOuter = 0 ; repOuter < numRepOuter; ++repOuter )
262263 for (unsigned k = 0 ; k < numRepK; ++k)
263264 for (unsigned rep = 0 ; rep < repClusterSize; ++rep) {
264265 for (unsigned elemId = 0 ; elemId < numElemPerInstPerThread; ++elemId) {
@@ -268,9 +269,9 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
268269 (opIdx == 0 ) ? elemId % numOpsPerPackedValue : 0 ;
269270 unsigned packedElemId = elemId / numOpsPerPackedValue;
270271 unsigned repRowIndex =
271- shapePerCTATile[0 ] * (opIdx == 0 ? dimOuter : k);
272+ shapePerCTATile[rank - 2 ] * (opIdx == 0 ? repOuter : k);
272273 unsigned repColIndex =
273- shapePerCTATile[1 ] * (opIdx == 0 ? k : dimOuter );
274+ shapePerCTATile[rank - 1 ] * (opIdx == 0 ? k : repOuter );
274275 unsigned repClusterRowIndex = opIdx == 0 ? rep * instShape[0 ] : 0 ;
275276 unsigned repClusterColIndex = opIdx == 0 ? 0 : rep * instShape[1 ];
276277 unsigned packedElemRowIndex =
@@ -279,10 +280,17 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
279280 unsigned packedElemColIndex =
280281 (packedElemId % numColsPerLaneForPackedValue) *
281282 numColsPerPackedValue;
282- offsets.push_back ({repRowIndex + repClusterRowIndex +
283- packedElemRowIndex + opsRowIndex,
284- repColIndex + repClusterColIndex +
285- packedElemColIndex + opsColIndex});
283+ if (rank == 3 )
284+ offsets.push_back ({0 ,
285+ repRowIndex + repClusterRowIndex +
286+ packedElemRowIndex + opsRowIndex,
287+ repColIndex + repClusterColIndex +
288+ packedElemColIndex + opsColIndex});
289+ else
290+ offsets.push_back ({repRowIndex + repClusterRowIndex +
291+ packedElemRowIndex + opsRowIndex,
292+ repColIndex + repClusterColIndex +
293+ packedElemColIndex + opsColIndex});
286294 }
287295 }
288296
@@ -560,6 +568,7 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
560568
561569inline SmallVector<SmallVector<unsigned >>
562570emitOffsetForLayout (Attribute layout, RankedTensorType type) {
571+ std::cout << " ~! emitOffsetForLayout\n " ;
563572 if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout))
564573 return emitOffsetForDpasLayout (dpasLayout, type);
565574 if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
0 commit comments