@@ -164,8 +164,9 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
164164 SmallVector<unsigned > sizePerThreads = getSizePerThread (dpasLayout);
165165 ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
166166 size_t rank = repCluster.size ();
167- SmallVector<unsigned > sizePerDPASInst = {sizePerThreads[0 ] / repCluster[0 ],
168- sizePerThreads[1 ] / repCluster[1 ]};
167+ SmallVector<unsigned > sizePerDPASInst = {
168+ sizePerThreads[rank - 2 ] / repCluster[rank - 2 ],
169+ sizePerThreads[rank - 1 ] / repCluster[rank - 1 ]};
169170
170171 unsigned rowsPerElem = dpasLayout.getSubGroupSize () / instShapeC[1 ];
171172 unsigned colsPerElem = 1 ;
@@ -176,15 +177,19 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
176177 for (unsigned elemId = 0 ; elemId < elemNumberPerRep; ++elemId) {
177178 // Follows the C++ order for the dpas layout.
178179 SmallVector<unsigned > repOffset = {
179- (repId / repCluster[1 ]) * instShapeC[0 ],
180- (repId % repCluster[1 ]) * instShapeC[1 ]};
180+ (repId / repCluster[rank - 1 ]) * instShapeC[0 ],
181+ (repId % repCluster[rank - 1 ]) * instShapeC[1 ]};
181182
182183 SmallVector<unsigned > elemOffset = {
183184 (elemId / sizePerDPASInst[1 ]) * rowsPerElem,
184185 (elemId % sizePerDPASInst[1 ]) * colsPerElem};
185186
186- offsets.push_back ({repOffset[0 ] + elemOffset[0 ] + ctaOffsetX,
187- repOffset[1 ] + elemOffset[1 ] + ctaOffsetY});
187+ if (rank == 3 )
188+ offsets.push_back ({0 , repOffset[0 ] + elemOffset[0 ] + ctaOffsetX,
189+ repOffset[1 ] + elemOffset[1 ] + ctaOffsetY});
190+ else
191+ offsets.push_back ({repOffset[0 ] + elemOffset[0 ] + ctaOffsetX,
192+ repOffset[1 ] + elemOffset[1 ] + ctaOffsetY});
188193 }
189194 }
190195}
@@ -289,9 +294,10 @@ emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout,
289294 ArrayRef<int64_t > shape = type.getShape ();
290295 SmallVector<SmallVector<unsigned >> offsets;
291296 SmallVector<unsigned > shapePerCTA = getShapePerCTATile (dpasLayout);
297+ size_t rank = shape.size ();
292298
293- for (unsigned i = 0 ; i < shape[0 ]; i += shapePerCTA[0 ]) {
294- for (unsigned j = 0 ; j < shape[1 ]; j += shapePerCTA[1 ]) {
299+ for (unsigned i = 0 ; i < shape[rank - 2 ]; i += shapePerCTA[rank - 2 ]) {
300+ for (unsigned j = 0 ; j < shape[rank - 1 ]; j += shapePerCTA[rank - 1 ]) {
295301 emitOffsetForDpasLayoutPerCTA (dpasLayout, offsets, i, j);
296302 }
297303 }
@@ -333,13 +339,14 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
333339 size_t rank = warpShape.size ();
334340 assert (rank == shapePerCTA.size () && " Rank mismatch" );
335341 Value warpIndex =
336- (opIdx == 0 ) ? urem (multiDimWarpId[0 ],
342+ (opIdx == 0 ) ? urem (multiDimWarpId[rank - 2 ],
337343 i32_val (mlir::ceil<unsigned >(shapePerCTA[rank - 2 ],
338344 warpShape[rank - 2 ])))
339- : urem (multiDimWarpId[1 ],
345+ : urem (multiDimWarpId[rank - 1 ],
340346 i32_val (mlir::ceil<unsigned >(shapePerCTA[rank - 1 ],
341347 warpShape[rank - 1 ])));
342- Value warpOffset = mul (warpIndex, i32_val (warpShape[opIdx]));
348+ Value warpOffset =
349+ mul (warpIndex, i32_val (warpShape[opIdx ? rank - 1 : rank - 2 ]));
343350
344351 // Compute the 2-dim coordinates of the first element in the warp operated
345352 // own by this thread.
@@ -355,7 +362,7 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
355362 // Unlike the operand B, to pack the value to i16 for scalar bit width
356363 // <=16.
357364 unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1 ;
358- unsigned packedColNum = shapeA[1 ] / packedOpsPerLane;
365+ unsigned packedColNum = shapeA[rank - 1 ] / packedOpsPerLane;
359366 if (warpSize < packedColNum)
360367 llvm::report_fatal_error (
361368 " DpasEncodingAttr sub-group size could not "
@@ -375,12 +382,18 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
375382 laneRowIndex = mul (laneRowIndex, i32_val (opsPerChannel));
376383 laneColIndex = urem (laneId, i32_val (executionSize));
377384 } break ;
385+ default : {
386+ llvm::report_fatal_error (" Only support opIdx 1 or 0 for DotOpLayout." );
387+ }
378388 }
379389
380- auto multiDimBase =
381- (opIdx == 0 )
382- ? SmallVector<Value>{add (laneRowIndex, warpOffset), laneColIndex}
383- : SmallVector<Value>{laneRowIndex, add (laneColIndex, warpOffset)};
390+ SmallVector<Value> multiDimBase (rank);
391+ if (rank == 3 )
392+ multiDimBase[0 ] = multiDimWarpId[0 ];
393+ multiDimBase[rank - 2 ] =
394+ (opIdx == 0 ) ? add (laneRowIndex, warpOffset) : laneRowIndex;
395+ multiDimBase[rank - 1 ] =
396+ (opIdx == 0 ) ? laneColIndex : add (laneColIndex, warpOffset);
384397
385398 return multiDimBase;
386399}
@@ -394,6 +407,7 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
394407 Value warpId = udiv (threadId, warpSize);
395408 Value laneId = urem (threadId, warpSize);
396409
410+ unsigned rank = type.getShape ().size ();
397411 auto warpsPerCTA = dpasLayout.getWarpsPerCTA ();
398412 ArrayRef<int64_t > shape = type.getShape ();
399413
@@ -404,19 +418,25 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
404418 // Compute the 2-dim coordinates of the warp containing the tensor element
405419 // operated on by this thread.
406420 SmallVector<unsigned > warpShape = dpasLayout.getShapeC ();
407- Value rowWarpId = urem (multiDimWarpId[0 ],
408- i32_val (mlir::ceil<unsigned >(shape[0 ], warpShape[0 ])));
409- Value colWarpId = urem (multiDimWarpId[1 ],
410- i32_val (mlir::ceil<unsigned >(shape[1 ], warpShape[1 ])));
411- Value rowWarpOffset = mul (rowWarpId, i32_val (warpShape[0 ]));
412- Value colWarpOffset = mul (colWarpId, i32_val (warpShape[1 ]));
421+ Value rowWarpId =
422+ urem (multiDimWarpId[rank - 2 ],
423+ i32_val (mlir::ceil<unsigned >(shape[rank - 2 ], warpShape[rank - 2 ])));
424+ Value colWarpId =
425+ urem (multiDimWarpId[rank - 1 ],
426+ i32_val (mlir::ceil<unsigned >(shape[rank - 1 ], warpShape[rank - 1 ])));
427+ Value rowWarpOffset = mul (rowWarpId, i32_val (warpShape[rank - 2 ]));
428+ Value colWarpOffset = mul (colWarpId, i32_val (warpShape[rank - 1 ]));
413429
414430 // Compute the 2-dim coordinates of the first element in the warp operated
415431 // on by this thread.
416432 SmallVector<unsigned > threadsPerWarp = getThreadsPerWarp (dpasLayout);
417- SmallVector<Value> multiDimBase = {
418- add (udiv (laneId, i32_val (threadsPerWarp[1 ])), rowWarpOffset),
419- add (urem (laneId, i32_val (threadsPerWarp[1 ])), colWarpOffset)};
433+ SmallVector<Value> multiDimBase (rank);
434+ if (rank == 3 )
435+ multiDimBase[0 ] = multiDimWarpId[0 ];
436+ multiDimBase[rank - 2 ] =
437+ add (udiv (laneId, i32_val (threadsPerWarp[rank - 1 ])), rowWarpOffset);
438+ multiDimBase[rank - 1 ] =
439+ add (urem (laneId, i32_val (threadsPerWarp[rank - 1 ])), colWarpOffset);
420440 return multiDimBase;
421441}
422442
0 commit comments