
Commit e82dfd9

[BACKEND] Minor Bugfixes for SharedToDotOperand MMAv3 (#5030)
Two bugfixes following triton-lang/triton#5009.

- When `BLOCK_M=64` and `num_warps > 4`, the order of warps for a DotOpEncoded tensor should be M-major instead of N-major, since WGMMA expects the 4 warps in each warp group to be stacked along the M dimension.
- `mmaBitwidth` should be used instead of `bitwidth` when calculating `numRep` in `SharedToDotOperandMMAv2OrV3`; this was missed in a bad rebase.

@lezcano I encountered these bugs when attempting to locally test the [DotOp hoisting PR](triton-lang/triton#5003) after rebasing (they would normally be caught by `test_core.py`, but that path was not yet enabled in the last PR). With these fixes added, I was able to successfully validate against PyTorch.
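To illustrate why the warp order matters here, below is a minimal standalone sketch (not Triton code): the `warpsPerCTA = {4, 2}` shape, the 8-warp count, and the tiny `delinearize` helper are assumptions chosen for the example, mirroring only the idea that the first entry of the order is the fastest-varying dimension when unpacking a linear warp id.

```cpp
// Standalone sketch: how the delinearization order assigns warps to (M, N)
// positions. Shapes and helper are illustrative assumptions, not Triton code.
#include <array>
#include <cstdio>

// order[0] is the fastest-varying dimension when unpacking the linear id.
std::array<int, 2> delinearize(int linear, std::array<int, 2> shape,
                               std::array<int, 2> order) {
  std::array<int, 2> multi{};
  for (int dim : order) {
    multi[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multi;
}

int main() {
  std::array<int, 2> warpsPerCTA = {4, 2}; // {M, N}, 8 warps total
  std::array<int, 2> mMajor = {0, 1};      // M varies fastest
  std::array<int, 2> nMajor = {1, 0};      // N varies fastest
  for (int w = 0; w < 8; ++w) {
    auto a = delinearize(w, warpsPerCTA, mMajor);
    auto b = delinearize(w, warpsPerCTA, nMajor);
    std::printf("warp %d: M-major (M=%d,N=%d)  N-major (M=%d,N=%d)\n", w,
                a[0], a[1], b[0], b[1]);
  }
  // With the M-major order, warps 0-3 (one warp group) all land at N=0 and
  // stack along M, which matches the WGMMA expectation; the N-major order
  // spreads the warps of a group across N instead.
  return 0;
}
```

With the M-major order the first warp group occupies consecutive M positions at N=0; with the N-major order, warps within a group straddle both N tiles, which is the mismatch the fix addresses.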
1 parent 04d655e commit e82dfd9

File tree

1 file changed: +3, -3 lines


third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp

Lines changed: 3 additions & 3 deletions
@@ -659,15 +659,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,

   int kWidth = encoding.getKWidth();
   auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
-      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());
+      shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx());

   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
-  auto order = triton::gpu::getOrder(mmaLayout);
+  auto warpOrder = mmaLayout.getWarpOrder();
   Value warp = udiv(thread, i32_val(32));
   Value lane = urem(thread, i32_val(32));

   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warp, warpsPerCTA, order);
+      delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder);
   Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0]));
   int warpsPerTile;
   Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16));
