Commit 6af4f88
Fix 3xTF32 precision issues (#4934)
See #4603.
---
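For background, 3xTF32 emulates an FP32 matmul with three TF32 tensor-core products: each FP32 operand is split into a TF32-representable "high" part plus a small remainder, the high/cross products are computed in TF32, and the tiny remainder-times-remainder term is dropped. Below is a minimal PyTorch sketch of the decomposition; the truncation-based `to_tf32` helper and the summation order are illustrative assumptions (the hardware conversion rounds to nearest):
```python
import torch

def to_tf32(x: torch.Tensor) -> torch.Tensor:
    # Keep only the top 10 mantissa bits of each FP32 value (TF32 layout).
    # Truncation for simplicity; the hardware conversion rounds to nearest.
    return (x.view(torch.int32) & -8192).view(torch.float32)  # -8192 == ~0x1FFF

def mm_3xtf32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    a_hi, b_hi = to_tf32(a), to_tf32(b)
    a_lo, b_lo = a - a_hi, b - b_hi
    # Three products instead of one; a_lo @ b_lo is dropped as negligible.
    # Adding the small cross terms before the dominant hi @ hi product
    # tends to lose less precision than the opposite order.
    return (a_lo @ b_hi + a_hi @ b_lo) + a_hi @ b_hi

a, b = torch.randn(512, 512), torch.randn(512, 512)
print((mm_3xtf32(a, b) - a @ b).abs().max())  # residual from split + dropped term
```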
Here is a precision/performance comparison before/after this PR (on an A100):
<details>
<summary>The script used for testing</summary>
```python
import pandas as pd
import torch
import torch.utils.benchmark as benchmark

import triton
import triton.language as tl

import cutlass

dtype = torch.float32
device = "cuda"
loss = torch.nn.MSELoss()


def cutlass_mm(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    m, n = a.shape[0], b.shape[1]
    d = torch.empty((m, n), dtype=a.dtype, device=a.device)
    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
    plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32
    alpha = 1
    beta = 0
    plan.run(a, b, d, d, alpha, beta, print_module=False)
    return d


@triton.jit
def triton_mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Grouped program ordering for better L2 reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # The 3xTF32 path under test.
        accumulator = tl.dot(a, b, accumulator, input_precision="tf32x3")
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_mm(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 32
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    triton_mm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
        num_warps=8,
        num_stages=3,
    )
    return c


torch.manual_seed(1234)

# Precision: MSE of each 3xTF32 result against a full-FP32 cuBLAS reference.
dims = []
triton_3xtf32_loss = []
cutlass_3xtf32_loss = []
for m in range(256, 4096, 128):
    n = k = m
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)
    # Compute the reference with TF32 disabled so it is true FP32.
    allow_tf32_saved = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    d_ref = torch.mm(a, b)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32_saved
    d_triton_3xtf32 = triton_mm(a, b)
    d_cutlass_3xtf32 = cutlass_mm(a, b)
    dims.append(m)
    triton_3xtf32_loss.append(loss(d_triton_3xtf32, d_ref).item())
    cutlass_3xtf32_loss.append(loss(d_cutlass_3xtf32, d_ref).item())

df = pd.DataFrame(
    {
        "dims": dims,
        "Triton 3xTF32 loss": triton_3xtf32_loss,
        "CUTLASS 3xTF32 loss": cutlass_3xtf32_loss,
    }
)
print(df)
print()

# Performance: Triton 3xTF32 vs. CUTLASS 3xTF32 latency.
results = []
label = "Triton 3xTF32 vs. CUTLASS 3xTF32 latency"
for m in range(256, 4096, 128):
    n = k = m  # keep the matrices square
    sub_label = f"m = n = k = {m:5d}"
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)
    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": triton_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="Triton 3xTF32",
    ).blocked_autorange()
    results.append(measurement)
    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": cutlass_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="CUTLASS",
    ).blocked_autorange()
    results.append(measurement)

compare = benchmark.Compare(results)
compare.print()
```
</details>
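One methodological caveat: the FP32 `torch.mm` reference above also rounds, so the reported MSE includes a little reference noise. A tighter baseline, assuming memory allows, would be a float64 matmul:
```python
# Hypothetical drop-in replacement for the d_ref computation in the loop above:
d_ref = (a.double() @ b.double()).float()
```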
<details>
<summary>Test script output for vanilla Triton build</summary>
```
dims Triton 3xTF32 loss CUTLASS 3xTF32 loss
0 256 1.366855e-09 5.235101e-11
1 384 4.742662e-09 8.836381e-11
2 512 1.157405e-08 1.270737e-10
3 640 2.254077e-08 1.644706e-10
4 768 3.873695e-08 2.048905e-10
5 896 6.212847e-08 2.524800e-10
6 1024 9.253924e-08 2.843547e-10
7 1152 1.318329e-07 3.507732e-10
8 1280 1.823635e-07 7.997096e-10
9 1408 2.423697e-07 4.624160e-10
10 1536 3.152084e-07 5.258877e-10
11 1664 3.999571e-07 5.849541e-10
12 1792 5.002328e-07 6.518351e-10
13 1920 6.167757e-07 1.014158e-09
14 2048 7.500014e-07 1.800559e-09
15 2176 8.983116e-07 2.005555e-09
16 2304 1.064476e-06 2.212916e-09
17 2432 1.255128e-06 2.445486e-09
18 2560 1.461378e-06 2.680297e-09
19 2688 1.688605e-06 2.921828e-09
20 2816 1.943802e-06 3.181862e-09
21 2944 2.224484e-06 3.454009e-09
22 3072 2.519756e-06 3.732411e-09
23 3200 2.850649e-06 4.019436e-09
24 3328 3.207230e-06 4.322690e-09
25 3456 3.598114e-06 4.644620e-09
26 3584 4.016068e-06 4.967569e-09
27 3712 4.458372e-06 5.296403e-09
28 3840 4.932218e-06 5.642412e-09
29 3968 5.452913e-06 6.006925e-09
[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
| Triton 3xTF32 | CUTLASS
1 threads: ------------------------------------------
m = n = k = 256 | 525.7 | 1059.2
m = n = k = 384 | 526.6 | 1098.4
m = n = k = 512 | 1047.8 | 1385.1
m = n = k = 640 | 1049.2 | 1606.8
m = n = k = 768 | 1050.2 | 1601.3
m = n = k = 896 | 1565.3 | 1700.0
m = n = k = 1024 | 1571.0 | 1712.2
m = n = k = 1152 | 1572.6 | 1912.5
m = n = k = 1280 | 1573.4 | 1907.2
m = n = k = 1408 | 2092.4 | 2248.6
m = n = k = 1536 | 2094.0 | 2260.1
m = n = k = 1664 | 2095.1 | 2242.8
m = n = k = 1792 | 2612.2 | 2580.8
m = n = k = 1920 | 2615.5 | 2611.6
m = n = k = 2048 | 2617.1 | 2582.8
m = n = k = 2176 | 2618.3 | 2696.4
m = n = k = 2304 | 3136.8 | 2903.1
m = n = k = 2432 | 3139.2 | 2915.2
m = n = k = 2560 | 3144.3 | 2915.3
m = n = k = 2688 | 3649.2 | 3270.4
m = n = k = 2816 | 3660.1 | 3241.2
m = n = k = 2944 | 3661.4 | 3331.5
m = n = k = 3072 | 3664.0 | 3048.8
m = n = k = 3200 | 4180.4 | 3379.3
m = n = k = 3328 | 4182.6 | 3395.0
m = n = k = 3456 | 4184.5 | 3384.3
m = n = k = 3584 | 4690.9 | 3712.1
m = n = k = 3712 | 4707.7 | 3921.7
m = n = k = 3840 | 4706.4 | 3919.7
m = n = k = 3968 | 4708.1 | 3707.0
Times are in microseconds (us).
```
</details>
<details>
<summary>Test script output for Triton build with this PR applied</summary>
```
dims Triton 3xTF32 loss CUTLASS 3xTF32 loss
0 256 9.949744e-12 5.235101e-11
1 384 2.407365e-11 8.836381e-11
2 512 3.835959e-11 1.270737e-10
3 640 5.498505e-11 1.644706e-10
4 768 7.436918e-11 2.048905e-10
5 896 9.789199e-11 2.524800e-10
6 1024 1.072674e-10 2.843547e-10
7 1152 1.520337e-10 3.507732e-10
8 1280 5.775638e-10 7.997096e-10
9 1408 2.184144e-10 4.624160e-10
10 1536 2.571353e-10 5.258877e-10
11 1664 2.963491e-10 5.849541e-10
12 1792 3.402902e-10 6.518351e-10
13 1920 6.804675e-10 1.014158e-09
14 2048 1.443346e-09 1.800559e-09
15 2176 1.625424e-09 2.005555e-09
16 2304 1.813113e-09 2.212916e-09
17 2432 2.018629e-09 2.445486e-09
18 2560 2.232485e-09 2.680297e-09
19 2688 2.452671e-09 2.921828e-09
20 2816 2.689190e-09 3.181862e-09
21 2944 2.937780e-09 3.454009e-09
22 3072 3.193837e-09 3.732411e-09
23 3200 3.460724e-09 4.019436e-09
24 3328 3.738940e-09 4.322690e-09
25 3456 4.038074e-09 4.644620e-09
26 3584 4.338085e-09 4.967569e-09
27 3712 4.644735e-09 5.296403e-09
28 3840 4.969717e-09 5.642412e-09
29 3968 5.309353e-09 6.006925e-09
[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
| Triton 3xTF32 | CUTLASS
1 threads: ------------------------------------------
m = n = k = 256 | 701.4 | 1058.7
m = n = k = 384 | 704.7 | 1103.7
m = n = k = 512 | 1392.3 | 1394.9
m = n = k = 640 | 1393.9 | 1387.5
m = n = k = 768 | 1395.9 | 1389.7
m = n = k = 896 | 2077.6 | 1739.9
m = n = k = 1024 | 2088.4 | 1730.4
m = n = k = 1152 | 2100.7 | 1737.3
m = n = k = 1280 | 2094.9 | 1759.5
m = n = k = 1408 | 2790.6 | 2258.8
m = n = k = 1536 | 2786.3 | 2332.9
m = n = k = 1664 | 2788.9 | 2251.7
m = n = k = 1792 | 3470.9 | 2618.0
m = n = k = 1920 | 3479.4 | 2596.3
m = n = k = 2048 | 3480.7 | 2407.4
m = n = k = 2176 | 3498.1 | 2541.2
m = n = k = 2304 | 4177.2 | 2941.1
m = n = k = 2432 | 4177.3 | 2765.4
m = n = k = 2560 | 4180.9 | 2932.8
m = n = k = 2688 | 4864.3 | 3100.1
m = n = k = 2816 | 4871.8 | 3039.4
m = n = k = 2944 | 4873.1 | 3240.2
m = n = k = 3072 | 4875.7 | 3060.8
m = n = k = 3200 | 5580.7 | 3638.5
m = n = k = 3328 | 5573.4 | 3442.0
m = n = k = 3456 | 5572.4 | 3583.3
m = n = k = 3584 | 6259.6 | 3902.5
m = n = k = 3712 | 6263.7 | 3909.3
m = n = k = 3840 | 6263.8 | 3721.8
m = n = k = 3968 | 6268.2 | 3941.8
Times are in microseconds (us).
```
</details>

In short: with this PR, the Triton 3xTF32 loss drops by two to three orders of magnitude and is now consistently below the CUTLASS 3xTF32 loss, at the cost of roughly 30% higher kernel latency.
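To put the loss numbers in perspective: for i.i.d. N(0, 1) inputs with m = n = k = 1024, entries of the product have standard deviation ≈ √k = 32, so the vanilla MSE of ~9.3e-08 corresponds to a relative RMS error around 1e-05, while the patched MSE of ~1.1e-10 corresponds to roughly 3e-07, close to FP32 machine epsilon (~1.2e-07). A quick check:
```python
import math

k = 1024
for name, mse in [("vanilla", 9.253924e-08), ("patched", 1.072674e-10)]:
    rms = math.sqrt(mse)
    # For a @ b with i.i.d. N(0, 1) entries, each output entry has std ~ sqrt(k).
    print(f"{name}: RMS abs err = {rms:.2e}, rel err ~ {rms / math.sqrt(k):.2e}")
```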
1 file changed: `lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp` (+15, -2)
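Conceptually, `F32DotTC.cpp` is where a single `input_precision="tf32x3"` dot gets rewritten into three plain TF32 dots over the split operands, mirroring the PyTorch sketch above. An illustrative Triton-language analogue follows; the names, truncation-based split, and accumulation order are assumptions for illustration, since the actual rewrite operates on TritonGPU IR:
```python
import triton
import triton.language as tl

@triton.jit
def dot_3xtf32_sketch(a, b, acc):
    # Split each FP32 operand into a TF32-representable high part and a
    # low remainder (truncation shown; the pass's exact rounding may differ).
    a_hi = (a.to(tl.int32, bitcast=True) & -8192).to(tl.float32, bitcast=True)
    b_hi = (b.to(tl.int32, bitcast=True) & -8192).to(tl.float32, bitcast=True)
    a_lo = a - a_hi
    b_lo = b - b_hi
    # Accumulate the small cross terms before the dominant hi x hi product.
    acc = tl.dot(a_lo, b_hi, acc, input_precision="tf32")
    acc = tl.dot(a_hi, b_lo, acc, input_precision="tf32")
    acc = tl.dot(a_hi, b_hi, acc, input_precision="tf32")
    return acc
```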