@@ -228,33 +228,11 @@ sharedToLinearLayoutAMDRotating(ArrayRef<int64_t> shape,
228228 return combineCtaCgaWithShape (ctaLayout, shared.getCTALayout (), shape);
229229}
230230
231- } // namespace
232-
233- LinearLayout nvmmaSharedToLinearLayout (ArrayRef<int64_t > shape,
234- NVMMASharedEncodingAttr shared,
231+ // Returns the layout of a single core matrix which tiles the nvmma layout
232+ LinearLayout getCoreMatrixLinearLayout (NVMMASharedEncodingAttr shared,
235233 bool disableSwizzle) {
236- MLIRContext *ctx = shared.getContext ();
237- int rank = shape.size ();
238- auto shapePerCTA = getShapePerCTA (shared, shape);
239- if (rank == 1 ) {
240- // TODO: Not sure if this is correct.
241- return combineCtaCgaWithShape (
242- LinearLayout::identity1D (shapePerCTA[0 ], S (" offset" ), S (" dim0" )),
243- shared.getCTALayout (), shape);
244- }
245- // Construct bases for a the layout's 2-dimensional tile.
246- assert (rank >= 2 );
247- int batchDims = rank - 2 ;
234+ auto *ctx = shared.getContext ();
248235
249- // Collapse all the outer dim into one. We will then create a layout for this
250- // shape and reshape it to the original shape.
251- std::array<int64_t , 2 > collapsedShapePerCTA = {shapePerCTA[batchDims],
252- shapePerCTA[batchDims + 1 ]};
253- for (int i = 0 ; i < batchDims; i++)
254- collapsedShapePerCTA[0 ] *= shapePerCTA[i];
255- if (shared.getTransposed ()) {
256- std::swap (collapsedShapePerCTA[0 ], collapsedShapePerCTA[1 ]);
257- }
258236 int elemBitWidth = shared.getElementBitWidth ();
259237 int tileWidthBytes = shared.getSwizzlingByteWidth ();
260238 int vec = 128 / elemBitWidth;
@@ -273,25 +251,9 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
273251
274252 int tileRows = 8 ;
275253 int tileCols = 8 * tileWidthBytes / elemBitWidth;
276- bool isFp4Padded = false ;
277- if (auto sharedMMALayout =
278- dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(shared)) {
279- if (sharedMMALayout.getFp4Padded ()) {
280- isFp4Padded = true ;
281- }
282- }
254+ bool isFp4Padded = shared.getFp4Padded ();
283255 int packingFactor = isFp4Padded ? 2 : 1 ;
284256
285- if (collapsedShapePerCTA[1 ] * packingFactor < tileCols ||
286- collapsedShapePerCTA[0 ] < tileRows) {
287- llvm::errs () << " Illegal shared layout; expected collapsed shapePerCTA to "
288- " be at least ["
289- << tileRows << " , " << tileCols << " ], collapsedShapePerCTA: ["
290- << collapsedShapePerCTA[0 ] << " , " << collapsedShapePerCTA[1 ]
291- << " ]\n " ;
292- llvm::report_fatal_error (" Illegal shared layout" );
293- }
294-
295257 std::vector<std::vector<int >> bases2D;
296258 for (int col = 1 ; col < tileCols; col *= 2 ) {
297259 if (isFp4Padded) {
@@ -309,30 +271,75 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
309271 for (int row = 1 ; row < tileRows; row *= 2 ) {
310272 if (disableSwizzle) {
311273 bases2D.push_back ({row, 0 });
312- continue ;
313- }
314- if (isFp4Padded) {
274+ } else if (isFp4Padded) {
315275 int colPadded = vec * ((row / perPhase) % maxPhase);
316276 int colPacked = colPadded / 16 * 8 + colPadded % 8 ;
317277 bases2D.push_back ({row, colPacked});
318278 } else {
319279 bases2D.push_back ({row, vec * ((row / perPhase) % maxPhase)});
320280 }
321281 }
282+ auto outDimNames = standardOutDimNames (ctx, 2 );
283+ auto kRow = outDimNames[1 ];
284+ auto kCol = outDimNames[0 ];
285+ LinearLayout tileLayout =
286+ LinearLayout ({{S (" offset" ), bases2D}}, {kRow , kCol });
287+ return tileLayout;
288+ }
289+
290+ } // namespace
291+
292+ LinearLayout nvmmaSharedToLinearLayout (ArrayRef<int64_t > shape,
293+ NVMMASharedEncodingAttr shared,
294+ bool disableSwizzle) {
295+ MLIRContext *ctx = shared.getContext ();
296+ int rank = shape.size ();
297+ auto shapePerCTA = getShapePerCTA (shared, shape);
298+ if (rank == 1 ) {
299+ // TODO: Not sure if this is correct.
300+ return combineCtaCgaWithShape (
301+ LinearLayout::identity1D (shapePerCTA[0 ], S (" offset" ), S (" dim0" )),
302+ shared.getCTALayout (), shape);
303+ }
304+ // Construct bases for a the layout's 2-dimensional tile.
305+ assert (rank >= 2 );
306+ int batchDims = rank - 2 ;
322307
323- // Then distribute the remaining rows.
324- for (int row = tileRows; row < collapsedShapePerCTA[0 ]; row *= 2 ) {
325- bases2D.push_back ({row, 0 });
308+ // Collapse all the outer dim into one. We will then create a layout for this
309+ // shape and reshape it to the original shape.
310+ std::array<int64_t , 2 > collapsedShapePerCTA{shapePerCTA[batchDims],
311+ shapePerCTA[batchDims + 1 ]};
312+ for (int i = 0 ; i < batchDims; i++)
313+ collapsedShapePerCTA[0 ] *= shapePerCTA[i];
314+ if (shared.getTransposed ()) {
315+ std::swap (collapsedShapePerCTA[0 ], collapsedShapePerCTA[1 ]);
326316 }
327317
318+ auto tileLayout = getCoreMatrixLinearLayout (shared, disableSwizzle);
328319 auto outDimNames = standardOutDimNames (ctx, 2 );
329- std::reverse (outDimNames.begin (), outDimNames.end ());
330- LinearLayout tileLayout = LinearLayout ({{S (" offset" ), bases2D}}, outDimNames);
331- // Expand the layout to convert the whole shape per CTA.
332- llvm::SmallDenseMap<StringAttr, int64_t > namedShape;
333- namedShape[outDimNames[0 ]] = collapsedShapePerCTA[0 ];
334- namedShape[outDimNames[1 ]] = collapsedShapePerCTA[1 ];
335- tileLayout = ensureLayoutNotSmallerThan (tileLayout, namedShape);
320+ auto kRow = outDimNames[1 ];
321+ auto kCol = outDimNames[0 ];
322+ auto tileRows = tileLayout.getOutDimSize (kRow );
323+ auto tileCols = tileLayout.getOutDimSize (kCol );
324+
325+ int packingFactor = shared.getFp4Padded () ? 2 : 1 ;
326+ if (collapsedShapePerCTA[1 ] * packingFactor < tileCols ||
327+ collapsedShapePerCTA[0 ] < tileRows) {
328+ llvm::errs () << " Illegal shared layout; expected collapsed shapePerCTA to "
329+ " be at least ["
330+ << tileRows << " , " << (tileCols / packingFactor)
331+ << " ], collapsedShapePerCTA: [" << collapsedShapePerCTA[0 ]
332+ << " , " << collapsedShapePerCTA[1 ] << " ]\n " ;
333+ llvm::report_fatal_error (" Illegal shared layout" );
334+ }
335+
336+ // Distribute the remaining rows and cols.
337+ auto kOffset = S (" offset" );
338+ auto layout = tileLayout;
339+ layout *= LinearLayout::identity1D (collapsedShapePerCTA[0 ] / tileRows,
340+ kOffset , kRow );
341+ layout *= LinearLayout::identity1D (collapsedShapePerCTA[1 ] / tileCols,
342+ kOffset , kCol );
336343
337344 // Reshape the layout to the N-D pre-transposed shape per CTA.
338345 SmallVector<int64_t > maybeTransposedShapePerCTA = shapePerCTA;
@@ -344,8 +351,7 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
344351 maybeTransposedShapePerCTA.begin () + 1 ,
345352 maybeTransposedShapePerCTA.end ());
346353 }
347- auto reshapedLayout =
348- reshapeLayout (ctx, tileLayout, maybeTransposedShapePerCTA);
354+ auto reshapedLayout = reshapeLayout (ctx, layout, maybeTransposedShapePerCTA);
349355
350356 if (shared.getTransposed ()) {
351357 SmallVector<int > order = {rank - 1 };
0 commit comments