Skip to content

Commit f36e7fc

Browse files
agron911meta-codesync[bot]
authored andcommitted
[Cherry-pick] [BACKEND] Fix matrix descriptor for no-swizzle case (#8027) (#558)
Summary: Cherry-picked from upstream OAI repository. Original Commit: 1c2e9bb Original Author: Thomas Raoux Original Date: 2025-09-02 00:07:03 -0700 Original commit message: ``` [BACKEND] Fix matrix descriptor for no-swizzle case (#8027) Also tighten the precision check to catch subtle fp8 precision diff ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #558 Reviewed By: dshi7 Differential Revision: D86147986 Pulled By: agron911 fbshipit-source-id: b4a9e09d1393b5a0d08d6f866d0079ab40c711a0
1 parent 833dc71 commit f36e7fc

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

python/test/unit/language/test_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ def test_simple_matmul(
211211
atol = 0.06
212212
rtol = 0.06
213213
else:
214-
atol = 0.01
215-
rtol = 0.01
214+
atol = 0.001
215+
rtol = 0.001
216216
torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
217217
# Make sure the mma is pipelined by checking if in the TTGIR we see two mmav5
218218
# operations. (Pipeliner will add additional mma operation by peeling the prologue.)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,12 @@ static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc,
103103
llvm::report_fatal_error("Unsupported swizzling size.");
104104
}
105105
if (swizzling == 0) {
106-
desc.leadDimensionBaseOffset = 16 >> 4; // 16 bytes.
107-
desc.strideDimensionBaseOffset = (8 * 16) >> 4;
106+
// Because the descriptor normalizes spacing to 128-bit units, the
107+
// normalized per-element stride is 16 bytes and LBO is defined as 8×that,
108+
// i.e. 128 bytes.
109+
desc.leadDimensionBaseOffset = 128 >> 4;
110+
// Offset from first row to second row 16x16 bytes.
111+
desc.strideDimensionBaseOffset = 256 >> 4;
108112
} else {
109113
desc.leadDimensionBaseOffset = (swizzling * stride) >> 4;
110114
desc.strideDimensionBaseOffset = swizzling >> 1;

0 commit comments

Comments
 (0)