@@ -288,60 +288,71 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
288288 return rewriter.notifyMatchFailure (
289289 op, " NYI. srcTy and/or dstTy don't implement LLs yet" );
290290 }
291+ LinearLayout srcLayout =
292+ *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
293+ LinearLayout dstLayout =
294+ *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
295+
296+ StringAttr kBlock = str_attr (" block" );
297+ StringAttr kWarp = str_attr (" warp" );
298+ StringAttr kLane = str_attr (" lane" );
299+ StringAttr kRegister = str_attr (" register" );
291300
292301 assert (to_vector (conversion->getInDimNames ()) ==
293302 to_vector (conversion->getOutDimNames ()));
294303 auto dims = conversion->getInDimNames ();
295- if (llvm::is_contained (dims, str_attr ( " block " ) )) {
304+ if (llvm::is_contained (dims, kBlock )) {
296305 // Case 1: Transfer between values in different CTAs.
297306 // This requires moving values through distributed shared memory.
298307 return rewriter.notifyMatchFailure (
299308 op, " NYI: Transfer between different CTAs" );
300- } else if (llvm::is_contained (dims, str_attr ( " warp " ) )) {
309+ } else if (llvm::is_contained (dims, kWarp )) {
301310 // Case 2: Transfer between values in the same CTA, in which case we move
302311 // values through shared memory.
303- LinearLayout srcLayout =
304- *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
305- LinearLayout dstLayout =
306- *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
307312 return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
308- } else if (llvm::is_contained (dims, str_attr ( " lane " ) )) {
313+ } else if (llvm::is_contained (dims, kLane )) {
309314 // Case 3. Transfer between values in the same warp, in which case we try
310315 // to move values using warp shuffles, though if the pattern is
311316 // complicated enough we may fall back to using shared memory
312317 // TODO(Keren): implement warp shuffle instead of using the general
313318 // approach that uses shared memory
314- LinearLayout srcLayout =
315- *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
316- LinearLayout dstLayout =
317- *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
318319 return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
319- } else if (llvm::is_contained (dims, str_attr (" register" ))) {
320+ } else if (llvm::is_contained (dims, kRegister ) ||
321+ dstLayout.getInDimSize (kRegister ) !=
322+ srcLayout.getInDimSize (kRegister )) {
320323 // Case 4. Transfer between values in the same thread, in which case we
321324 // simply reorder the elements of adaptor.getSrc().
322- return transferWithinThread (op, *conversion, adaptor, rewriter);
325+ return transferWithinThread (
326+ op, dstLayout.getFreeVariableMasks ()[kRegister ],
327+ dstLayout.getInDimSize (kRegister ), *conversion, adaptor, rewriter);
323328 } else {
324- // The two layouts are equivalent. We should probably remove these in
325- // RemoveLayoutConversion.
329+ // Case 5. The two layouts are equivalent. We should probably remove
330+ // these in RemoveLayoutConversion.
326331 rewriter.replaceOp (op, adaptor.getSrc ());
327332 return success ();
328333 }
329334 }
330335
331336 LogicalResult
332- transferWithinThread (ConvertLayoutOp op, const LinearLayout &conversion ,
333- OpAdaptor adaptor,
337+ transferWithinThread (ConvertLayoutOp op, int32_t regMasks, int32_t numRegs ,
338+ const LinearLayout &conversion, OpAdaptor adaptor,
334339 ConversionPatternRewriter &rewriter) const {
335340 MLIRContext *ctx = op.getContext ();
336341 auto loc = op.getLoc ();
337342 StringAttr kRegister = str_attr (" register" );
338343 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
339344
340345 auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
341- SmallVector<Value> outVals;
342- outVals.resize (conversion.getInDimSize (kRegister ));
343- for (int i = 0 ; i < conversion.getInDimSize (kRegister ); i++) {
344- auto srcIdx = conversion.apply ({{kRegister , i}}).begin ()->second ;
346+ SmallVector<Value> outVals (numRegs);
347+ for (int i = 0 ; i < outVals.size (); i++) {
348+ // Remove free masks from the register index
349+ // For example, if idx = 0b00111, and masks = 0b00100, then we get
350+ // 0b00011. It means that register 7 (0b111) has the same value as
351+ // register 3 (0b011).
352+ auto idx = i & (~regMasks);
353+ auto srcIdx = conversion.hasInDim (kRegister )
354+ ? conversion.apply ({{kRegister , idx}}).begin ()->second
355+ : idx;
345356 outVals[i] = inVals[srcIdx];
346357 }
347358 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
@@ -372,6 +383,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
372383 }
373384 return true ;
374385 }
386+ if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
387+ if (auto nvidiaMma =
388+ dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent ())) {
389+ if (product (getCTAsPerCGA (nvidiaMma)) > 1 ) {
390+ return false ;
391+ }
392+ if (useLegacyMMAConversion) {
393+ return false ;
394+ }
395+ // FIXME [Dot LL]
396+ // Enabling LL path for buggy kWidth path
397+ bool largeKWidth =
398+ dotOperand.getKWidth () * dstTy.getElementTypeBitWidth () > 64 ;
399+ return largeKWidth && nvidiaMma.isAmpere ();
400+ }
401+ }
375402 if (isa<BlockedEncodingAttr>(layout)) {
376403 return true ;
377404 }
@@ -431,6 +458,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
431458 }
432459 }
433460
461+ // FIXME [Dot LL]
462+ // We know it's just for largeKWidth case in Ampere
463+ // In this case, we need to pack the outputs into i32
464+ if (isa<DotOperandEncodingAttr>(dstTy.getEncoding ())) {
465+ auto concat = [&](Value a, Value b) {
466+ return or_ (zext (i32_ty, bitcast (a, i16_ty)),
467+ shl (zext (i32_ty, bitcast (b, i16_ty)), i32_val (16 )));
468+ };
469+
470+ SmallVector<Value> outVals32 (outVals.size () / 2 );
471+ for (int i = 0 ; i < outVals32.size (); ++i) {
472+ outVals32[i] = concat (outVals[2 * i], outVals[2 * i + 1 ]);
473+ }
474+ outVals = outVals32;
475+ }
476+
434477 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
435478 op.getType ());
436479 rewriter.replaceOp (op, result);
0 commit comments