Skip to content

Commit 6af4f88

Browse files
Fix 3xTF32 precision issues (#4934)
See #4603. --- Here is precision/performance comparison before/after this PR (on A100): <details> <summary>The script used for testing</summary> ``` import pandas as pd import torch import torch.utils.benchmark as benchmark import triton import triton.language as tl import cutlass dtype = torch.float32 device = "cuda" loss = torch.nn.MSELoss() def cutlass_mm(a, b): assert a.shape[1] == b.shape[0], "Incompatible dimensions" m, n = a.shape[0], b.shape[1] d = torch.empty((m, n), dtype=a.dtype, device=a.device) plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor) plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32 alpha = 1 beta = 0 plan.run(a, b, d, d, alpha, beta, print_module=False) return d @triton.jit def triton_mm_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) accumulator = tl.dot(a, b, accumulator, input_precision="tf32x3") a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk c = accumulator.to(tl.float32) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) def triton_mm(a, b): assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.is_contiguous(), "Matrix A must be contiguous" M, K = a.shape K, N = b.shape c = torch.empty((M, N), device=a.device, dtype=torch.float32) BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 32 grid = lambda META: ( triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) triton_mm_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=8, num_warps=8, num_stages=3, ) return c torch.manual_seed(1234) dims = [] triton_3xtf32_loss = [] cutlass_3xtf32_loss = [] for m in range(256, 4096, 128): n = k = m a = torch.randn((m, k), dtype=dtype, device=device) b = torch.randn((k, n), dtype=dtype, device=device) allow_tf32_saved = torch.backends.cuda.matmul.allow_tf32 torch.backends.cuda.matmul.allow_tf32 = False d_ref = torch.mm(a, b) torch.backends.cuda.matmul.allow_tf32 = allow_tf32_saved d_triton_3xtf32 = triton_mm(a, b) d_cutlass_3xtf32 = cutlass_mm(a, b) dims.append(m) triton_3xtf32_loss.append(loss(d_triton_3xtf32, d_ref).item()) cutlass_3xtf32_loss.append(loss(d_cutlass_3xtf32, d_ref).item()) df = pd.DataFrame( { "dims": dims, "Triton 3xTF32 loss": triton_3xtf32_loss, "CUTLASS 3xTF32 loss": cutlass_3xtf32_loss, } ) print(df) print() results = [] label = "Triton 3xTF32 vs. CUTLASS 3xTF32 latency" for m in range(256, 4096, 128): sub_label = f"m = n = k = {m:5d}" a = torch.randn((m, k), dtype=dtype, device=device) b = torch.randn((k, n), dtype=dtype, device=device) measurement = benchmark.Timer( stmt="mm(a, b)", globals={ "mm": triton_mm, "a": a, "b": b, }, label=label, sub_label=sub_label, description="Triton 3xTF32", ).blocked_autorange() results.append(measurement) measurement = benchmark.Timer( stmt="mm(a, b)", globals={ "mm": cutlass_mm, "a": a, "b": b, }, label=label, sub_label=sub_label, description="CUTLASS", ).blocked_autorange() results.append(measurement) compare = benchmark.Compare(results) compare.print() ``` </details> <details> <summary>Test script output for vanilla Triton build</summary> ``` dims Triton 3xTF32 loss CUTLASS 3xTF32 loss 0 256 1.366855e-09 5.235101e-11 1 384 4.742662e-09 8.836381e-11 2 512 1.157405e-08 1.270737e-10 3 640 2.254077e-08 1.644706e-10 4 768 3.873695e-08 2.048905e-10 5 896 6.212847e-08 2.524800e-10 6 1024 9.253924e-08 2.843547e-10 7 1152 1.318329e-07 3.507732e-10 8 1280 1.823635e-07 7.997096e-10 9 1408 2.423697e-07 4.624160e-10 10 1536 3.152084e-07 5.258877e-10 11 1664 3.999571e-07 5.849541e-10 12 1792 5.002328e-07 6.518351e-10 13 1920 6.167757e-07 1.014158e-09 14 2048 7.500014e-07 1.800559e-09 15 2176 8.983116e-07 2.005555e-09 16 2304 1.064476e-06 2.212916e-09 17 2432 1.255128e-06 2.445486e-09 18 2560 1.461378e-06 2.680297e-09 19 2688 1.688605e-06 2.921828e-09 20 2816 1.943802e-06 3.181862e-09 21 2944 2.224484e-06 3.454009e-09 22 3072 2.519756e-06 3.732411e-09 23 3200 2.850649e-06 4.019436e-09 24 3328 3.207230e-06 4.322690e-09 25 3456 3.598114e-06 4.644620e-09 26 3584 4.016068e-06 4.967569e-09 27 3712 4.458372e-06 5.296403e-09 28 3840 4.932218e-06 5.642412e-09 29 3968 5.452913e-06 6.006925e-09 [----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----] | Triton 3xTF32 | CUTLASS 1 threads: ------------------------------------------ m = n = k = 256 | 525.7 | 1059.2 m = n = k = 384 | 526.6 | 1098.4 m = n = k = 512 | 1047.8 | 1385.1 m = n = k = 640 | 1049.2 | 1606.8 m = n = k = 768 | 1050.2 | 1601.3 m = n = k = 896 | 1565.3 | 1700.0 m = n = k = 1024 | 1571.0 | 1712.2 m = n = k = 1152 | 1572.6 | 1912.5 m = n = k = 1280 | 1573.4 | 1907.2 m = n = k = 1408 | 2092.4 | 2248.6 m = n = k = 1536 | 2094.0 | 2260.1 m = n = k = 1664 | 2095.1 | 2242.8 m = n = k = 1792 | 2612.2 | 2580.8 m = n = k = 1920 | 2615.5 | 2611.6 m = n = k = 2048 | 2617.1 | 2582.8 m = n = k = 2176 | 2618.3 | 2696.4 m = n = k = 2304 | 3136.8 | 2903.1 m = n = k = 2432 | 3139.2 | 2915.2 m = n = k = 2560 | 3144.3 | 2915.3 m = n = k = 2688 | 3649.2 | 3270.4 m = n = k = 2816 | 3660.1 | 3241.2 m = n = k = 2944 | 3661.4 | 3331.5 m = n = k = 3072 | 3664.0 | 3048.8 m = n = k = 3200 | 4180.4 | 3379.3 m = n = k = 3328 | 4182.6 | 3395.0 m = n = k = 3456 | 4184.5 | 3384.3 m = n = k = 3584 | 4690.9 | 3712.1 m = n = k = 3712 | 4707.7 | 3921.7 m = n = k = 3840 | 4706.4 | 3919.7 m = n = k = 3968 | 4708.1 | 3707.0 Times are in microseconds (us). ``` </details> <details> <summary>Test script output for Triton build with this PR applied</summary> ``` dims Triton 3xTF32 loss CUTLASS 3xTF32 loss 0 256 9.949744e-12 5.235101e-11 1 384 2.407365e-11 8.836381e-11 2 512 3.835959e-11 1.270737e-10 3 640 5.498505e-11 1.644706e-10 4 768 7.436918e-11 2.048905e-10 5 896 9.789199e-11 2.524800e-10 6 1024 1.072674e-10 2.843547e-10 7 1152 1.520337e-10 3.507732e-10 8 1280 5.775638e-10 7.997096e-10 9 1408 2.184144e-10 4.624160e-10 10 1536 2.571353e-10 5.258877e-10 11 1664 2.963491e-10 5.849541e-10 12 1792 3.402902e-10 6.518351e-10 13 1920 6.804675e-10 1.014158e-09 14 2048 1.443346e-09 1.800559e-09 15 2176 1.625424e-09 2.005555e-09 16 2304 1.813113e-09 2.212916e-09 17 2432 2.018629e-09 2.445486e-09 18 2560 2.232485e-09 2.680297e-09 19 2688 2.452671e-09 2.921828e-09 20 2816 2.689190e-09 3.181862e-09 21 2944 2.937780e-09 3.454009e-09 22 3072 3.193837e-09 3.732411e-09 23 3200 3.460724e-09 4.019436e-09 24 3328 3.738940e-09 4.322690e-09 25 3456 4.038074e-09 4.644620e-09 26 3584 4.338085e-09 4.967569e-09 27 3712 4.644735e-09 5.296403e-09 28 3840 4.969717e-09 5.642412e-09 29 3968 5.309353e-09 6.006925e-09 [----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----] | Triton 3xTF32 | CUTLASS 1 threads: ------------------------------------------ m = n = k = 256 | 701.4 | 1058.7 m = n = k = 384 | 704.7 | 1103.7 m = n = k = 512 | 1392.3 | 1394.9 m = n = k = 640 | 1393.9 | 1387.5 m = n = k = 768 | 1395.9 | 1389.7 m = n = k = 896 | 2077.6 | 1739.9 m = n = k = 1024 | 2088.4 | 1730.4 m = n = k = 1152 | 2100.7 | 1737.3 m = n = k = 1280 | 2094.9 | 1759.5 m = n = k = 1408 | 2790.6 | 2258.8 m = n = k = 1536 | 2786.3 | 2332.9 m = n = k = 1664 | 2788.9 | 2251.7 m = n = k = 1792 | 3470.9 | 2618.0 m = n = k = 1920 | 3479.4 | 2596.3 m = n = k = 2048 | 3480.7 | 2407.4 m = n = k = 2176 | 3498.1 | 2541.2 m = n = k = 2304 | 4177.2 | 2941.1 m = n = k = 2432 | 4177.3 | 2765.4 m = n = k = 2560 | 4180.9 | 2932.8 m = n = k = 2688 | 4864.3 | 3100.1 m = n = k = 2816 | 4871.8 | 3039.4 m = n = k = 2944 | 4873.1 | 3240.2 m = n = k = 3072 | 4875.7 | 3060.8 m = n = k = 3200 | 5580.7 | 3638.5 m = n = k = 3328 | 5573.4 | 3442.0 m = n = k = 3456 | 5572.4 | 3583.3 m = n = k = 3584 | 6259.6 | 3902.5 m = n = k = 3712 | 6263.7 | 3909.3 m = n = k = 3840 | 6263.8 | 3721.8 m = n = k = 3968 | 6268.2 | 3941.8 Times are in microseconds (us). ``` </details>
1 parent fc8add9 commit 6af4f88

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ class TF32x3 : public OpRewritePattern<DotOp> {
4545
ArrayRef<Value>{value})
4646
.getResult()[0];
4747
};
48+
auto zeroLike = [&](Value c) -> Value {
49+
return rewriter.create<SplatOp>(
50+
dotOp->getLoc(), c.getType(),
51+
rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
52+
rewriter.getF32FloatAttr(0)));
53+
};
54+
auto add = [&](Value a, Value b) -> Value {
55+
return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
56+
};
4857
auto sub = [&](Value a, Value b) -> Value {
4958
return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
5059
};
@@ -60,11 +69,15 @@ class TF32x3 : public OpRewritePattern<DotOp> {
6069
auto bBig = f32ToTF32(dotOp.getB());
6170
auto bSmall = sub(dotOp.getB(), bBig);
6271

63-
auto dot1 = dot(aSmall, bBig, dotOp.getC());
72+
auto zero = zeroLike(dotOp.getC());
73+
74+
auto dot1 = dot(aSmall, bBig, zero);
6475
auto dot2 = dot(aBig, bSmall, dot1);
6576
auto dot3 = dot(aBig, bBig, dot2);
6677

67-
rewriter.replaceOp(dotOp, dot3);
78+
auto sum = add(dot3, dotOp.getC());
79+
80+
rewriter.replaceOp(dotOp, sum);
6881
return success();
6982
}
7083
};

0 commit comments

Comments
 (0)