Skip to content

Commit 933cefc

Browse files
authored
[TUTORIAL] Adjust rand number range for matmul tutorial (#7505)
This PR fixes the "Unit Test" in `03-matrix-multiplication.py` on MI300x and MI350x GPUs. `torch.randn((512, 512), device=DEVICE, dtype=torch.float16)` can generate relatively large absolute numbers in the input, which may lead to larger sums having a larger absolute roundoff error as the exponent grows. The unit test also passes on AMD MI250 and Nvidia H100.
1 parent 55db9b8 commit 933cefc

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

python/tutorials/03-matrix-multiplication.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -161,11 +161,6 @@ def is_cuda():
161161
return triton.runtime.driver.active.get_current_target().backend == "cuda"
162162

163163

164-
def is_hip_cdna2():
165-
target = triton.runtime.driver.active.get_current_target()
166-
return target.backend == 'hip' and target.arch == 'gfx90a'
167-
168-
169164
def get_cuda_autotune_config():
170165
return [
171166
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
@@ -364,17 +359,14 @@ def matmul(a, b, activation=""):
364359
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).
365360

366361
torch.manual_seed(0)
367-
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
368-
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
362+
a = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
363+
b = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
369364
triton_output = matmul(a, b)
370365
torch_output = torch.matmul(a, b)
371366
print(f"triton_output_with_fp16_inputs={triton_output}")
372367
print(f"torch_output_with_fp16_inputs={torch_output}")
373-
# Bigger tolerance for AMD CDNA2 devices.
374-
# CDNA2 devices use reduced precision fp16 and bf16 and flush input and
375-
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
376-
rtol = 1e-2 if is_hip_cdna2() else 0
377-
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
368+
369+
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
378370
print("✅ Triton and Torch match")
379371
else:
380372
print("❌ Triton and Torch differ")

0 commit comments

Comments
 (0)