Merge OpenAI Triton commit d207894
#2523
Merged
Conversation
Co-authored-by: Mario Lezcano Casado <[email protected]>
This adds a missing exception for the warp size and fixes the dot test for m or n > 32 when using wmma.
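For context, here is a minimal sketch of the kind of shape this covers: a single-block `tl.dot` with M and N larger than 32, checked against a PyTorch reference. The kernel and test below are hypothetical illustrations, not the actual test touched by this commit.

```python
# Hypothetical illustration of a dot test with M, N > 32 (not the test changed in this commit).
import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b)  # accumulates in fp32
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


def test_dot_large_mn(M=64, N=64, K=32):
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    c = torch.empty((M, N), device="cuda", dtype=torch.float32)
    dot_kernel[(1,)](a, b, c, M=M, N=N, K=K)
    torch.testing.assert_close(c, a.float() @ b.float(), atol=1e-2, rtol=1e-2)
```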
…4891) Note that the current implementation uses `DotOperandEncodingAttr::getWarpsPerCTA`, which was buggy for cases where the warps are not of the form `[numWarps, 1]` or `[1, numWarps]`. This PR bundles a fix for this issue. We will activate its use for a subset of `DotOperandEncoding`s in a PR coming soon.
Currently, PRs only get CI when the base branch is main or starts with `dev-`; this changes it to run on all PRs, which works better when using `git-pr-chain`, where the base branch usually isn't main.
- The epilogue ramp-down indexing must start at total_iterations - max_stage, clamped to zero or greater, to ensure alignment with the prologue ramp-up stages.
- If total_iterations < max_stage, the trailing stages will be masked.

This commit mirrors upstream llvm/llvm-project#112418 and adds a functional test for correctness with num_stages=1,2,3,4.
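As a rough illustration of the indexing rule above (a sketch of the arithmetic with hypothetical names, not the pipeliner code itself):

```python
# Sketch of the epilogue ramp-down indexing rule (hypothetical names, not the actual pipeliner).
def epilogue_iterations(total_iterations: int, max_stage: int):
    """Yield (iteration_index, is_masked) for each ramp-down slot."""
    # Start at total_iterations - max_stage, clamped to zero or greater,
    # so the ramp-down lines up with the prologue ramp-up stages.
    start = max(0, total_iterations - max_stage)
    for i in range(start, start + max_stage):
        # When total_iterations < max_stage, trailing slots refer to
        # iterations that never ran and must be masked out.
        yield i, i >= total_iterations
```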
To fix the following warnings:
```bash
Warning: hook id `ruff` uses deprecated stage names (commit, push) which will be removed in a future version. run: `pre-commit migrate-config` to automatically fix this.
Warning: hook id `yapf` uses deprecated stage names (commit, push) which will be removed in a future version. run: `pre-commit migrate-config` to automatically fix this.
Warning: hook id `clang-format` uses deprecated stage names (commit, push) which will be removed in a future version. run: `pre-commit migrate-config` to automatically fix this.
```
CI run, for example: https://github.com/triton-lang/triton/actions/runs/11346522167/job/31555782632#step:6:51

I just used the `pre-commit migrate-config` command.

Signed-off-by: Anatoly Myachev <[email protected]>
These optimizations break internal workloads
This PR includes triton-lang/triton#4891 and triton-lang/triton#4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` which will get rid of all these. triton-lang/triton#4895 is the first step in this direction.
…hMLIRArgs=true (#4931) The predicate wasn't being added to the instruction representation when `onlyAttachMLIRArgs` was set to true.
See #4603.

---

Here is precision/performance comparison before/after this PR (on A100):

<details>
<summary>The script used for testing</summary>

```
import pandas as pd
import torch
import torch.utils.benchmark as benchmark

import triton
import triton.language as tl

import cutlass

dtype = torch.float32
device = "cuda"

loss = torch.nn.MSELoss()


def cutlass_mm(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    m, n = a.shape[0], b.shape[1]
    d = torch.empty((m, n), dtype=a.dtype, device=a.device)
    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
    plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32
    alpha = 1
    beta = 0
    plan.run(a, b, d, d, alpha, beta, print_module=False)
    return d


@triton.jit
def triton_mm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, input_precision="tf32x3")
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_mm(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 32
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    triton_mm_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
        num_warps=8,
        num_stages=3,
    )
    return c


torch.manual_seed(1234)

dims = []
triton_3xtf32_loss = []
cutlass_3xtf32_loss = []
for m in range(256, 4096, 128):
    n = k = m
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)

    allow_tf32_saved = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    d_ref = torch.mm(a, b)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32_saved

    d_triton_3xtf32 = triton_mm(a, b)
    d_cutlass_3xtf32 = cutlass_mm(a, b)

    dims.append(m)
    triton_3xtf32_loss.append(loss(d_triton_3xtf32, d_ref).item())
    cutlass_3xtf32_loss.append(loss(d_cutlass_3xtf32, d_ref).item())

df = pd.DataFrame(
    {
        "dims": dims,
        "Triton 3xTF32 loss": triton_3xtf32_loss,
        "CUTLASS 3xTF32 loss": cutlass_3xtf32_loss,
    }
)
print(df)
print()

results = []
label = "Triton 3xTF32 vs. CUTLASS 3xTF32 latency"
for m in range(256, 4096, 128):
    sub_label = f"m = n = k = {m:5d}"
    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)
    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": triton_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="Triton 3xTF32",
    ).blocked_autorange()
    results.append(measurement)
    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": cutlass_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="CUTLASS",
    ).blocked_autorange()
    results.append(measurement)

compare = benchmark.Compare(results)
compare.print()
```

</details>

<details>
<summary>Test script output for vanilla Triton build</summary>

```
    dims  Triton 3xTF32 loss  CUTLASS 3xTF32 loss
 0   256        1.366855e-09         5.235101e-11
 1   384        4.742662e-09         8.836381e-11
 2   512        1.157405e-08         1.270737e-10
 3   640        2.254077e-08         1.644706e-10
 4   768        3.873695e-08         2.048905e-10
 5   896        6.212847e-08         2.524800e-10
 6  1024        9.253924e-08         2.843547e-10
 7  1152        1.318329e-07         3.507732e-10
 8  1280        1.823635e-07         7.997096e-10
 9  1408        2.423697e-07         4.624160e-10
10  1536        3.152084e-07         5.258877e-10
11  1664        3.999571e-07         5.849541e-10
12  1792        5.002328e-07         6.518351e-10
13  1920        6.167757e-07         1.014158e-09
14  2048        7.500014e-07         1.800559e-09
15  2176        8.983116e-07         2.005555e-09
16  2304        1.064476e-06         2.212916e-09
17  2432        1.255128e-06         2.445486e-09
18  2560        1.461378e-06         2.680297e-09
19  2688        1.688605e-06         2.921828e-09
20  2816        1.943802e-06         3.181862e-09
21  2944        2.224484e-06         3.454009e-09
22  3072        2.519756e-06         3.732411e-09
23  3200        2.850649e-06         4.019436e-09
24  3328        3.207230e-06         4.322690e-09
25  3456        3.598114e-06         4.644620e-09
26  3584        4.016068e-06         4.967569e-09
27  3712        4.458372e-06         5.296403e-09
28  3840        4.932218e-06         5.642412e-09
29  3968        5.452913e-06         6.006925e-09

[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
                        |  Triton 3xTF32  |  CUTLASS
1 threads: ------------------------------------------
      m = n = k =  256  |      525.7      |   1059.2
      m = n = k =  384  |      526.6      |   1098.4
      m = n = k =  512  |     1047.8      |   1385.1
      m = n = k =  640  |     1049.2      |   1606.8
      m = n = k =  768  |     1050.2      |   1601.3
      m = n = k =  896  |     1565.3      |   1700.0
      m = n = k = 1024  |     1571.0      |   1712.2
      m = n = k = 1152  |     1572.6      |   1912.5
      m = n = k = 1280  |     1573.4      |   1907.2
      m = n = k = 1408  |     2092.4      |   2248.6
      m = n = k = 1536  |     2094.0      |   2260.1
      m = n = k = 1664  |     2095.1      |   2242.8
      m = n = k = 1792  |     2612.2      |   2580.8
      m = n = k = 1920  |     2615.5      |   2611.6
      m = n = k = 2048  |     2617.1      |   2582.8
      m = n = k = 2176  |     2618.3      |   2696.4
      m = n = k = 2304  |     3136.8      |   2903.1
      m = n = k = 2432  |     3139.2      |   2915.2
      m = n = k = 2560  |     3144.3      |   2915.3
      m = n = k = 2688  |     3649.2      |   3270.4
      m = n = k = 2816  |     3660.1      |   3241.2
      m = n = k = 2944  |     3661.4      |   3331.5
      m = n = k = 3072  |     3664.0      |   3048.8
      m = n = k = 3200  |     4180.4      |   3379.3
      m = n = k = 3328  |     4182.6      |   3395.0
      m = n = k = 3456  |     4184.5      |   3384.3
      m = n = k = 3584  |     4690.9      |   3712.1
      m = n = k = 3712  |     4707.7      |   3921.7
      m = n = k = 3840  |     4706.4      |   3919.7
      m = n = k = 3968  |     4708.1      |   3707.0

Times are in microseconds (us).
```

</details>

<details>
<summary>Test script output for Triton build with this PR applied</summary>

```
    dims  Triton 3xTF32 loss  CUTLASS 3xTF32 loss
 0   256        9.949744e-12         5.235101e-11
 1   384        2.407365e-11         8.836381e-11
 2   512        3.835959e-11         1.270737e-10
 3   640        5.498505e-11         1.644706e-10
 4   768        7.436918e-11         2.048905e-10
 5   896        9.789199e-11         2.524800e-10
 6  1024        1.072674e-10         2.843547e-10
 7  1152        1.520337e-10         3.507732e-10
 8  1280        5.775638e-10         7.997096e-10
 9  1408        2.184144e-10         4.624160e-10
10  1536        2.571353e-10         5.258877e-10
11  1664        2.963491e-10         5.849541e-10
12  1792        3.402902e-10         6.518351e-10
13  1920        6.804675e-10         1.014158e-09
14  2048        1.443346e-09         1.800559e-09
15  2176        1.625424e-09         2.005555e-09
16  2304        1.813113e-09         2.212916e-09
17  2432        2.018629e-09         2.445486e-09
18  2560        2.232485e-09         2.680297e-09
19  2688        2.452671e-09         2.921828e-09
20  2816        2.689190e-09         3.181862e-09
21  2944        2.937780e-09         3.454009e-09
22  3072        3.193837e-09         3.732411e-09
23  3200        3.460724e-09         4.019436e-09
24  3328        3.738940e-09         4.322690e-09
25  3456        4.038074e-09         4.644620e-09
26  3584        4.338085e-09         4.967569e-09
27  3712        4.644735e-09         5.296403e-09
28  3840        4.969717e-09         5.642412e-09
29  3968        5.309353e-09         6.006925e-09

[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
                        |  Triton 3xTF32  |  CUTLASS
1 threads: ------------------------------------------
      m = n = k =  256  |      701.4      |   1058.7
      m = n = k =  384  |      704.7      |   1103.7
      m = n = k =  512  |     1392.3      |   1394.9
      m = n = k =  640  |     1393.9      |   1387.5
      m = n = k =  768  |     1395.9      |   1389.7
      m = n = k =  896  |     2077.6      |   1739.9
      m = n = k = 1024  |     2088.4      |   1730.4
      m = n = k = 1152  |     2100.7      |   1737.3
      m = n = k = 1280  |     2094.9      |   1759.5
      m = n = k = 1408  |     2790.6      |   2258.8
      m = n = k = 1536  |     2786.3      |   2332.9
      m = n = k = 1664  |     2788.9      |   2251.7
      m = n = k = 1792  |     3470.9      |   2618.0
      m = n = k = 1920  |     3479.4      |   2596.3
      m = n = k = 2048  |     3480.7      |   2407.4
      m = n = k = 2176  |     3498.1      |   2541.2
      m = n = k = 2304  |     4177.2      |   2941.1
      m = n = k = 2432  |     4177.3      |   2765.4
      m = n = k = 2560  |     4180.9      |   2932.8
      m = n = k = 2688  |     4864.3      |   3100.1
      m = n = k = 2816  |     4871.8      |   3039.4
      m = n = k = 2944  |     4873.1      |   3240.2
      m = n = k = 3072  |     4875.7      |   3060.8
      m = n = k = 3200  |     5580.7      |   3638.5
      m = n = k = 3328  |     5573.4      |   3442.0
      m = n = k = 3456  |     5572.4      |   3583.3
      m = n = k = 3584  |     6259.6      |   3902.5
      m = n = k = 3712  |     6263.7      |   3909.3
      m = n = k = 3840  |     6263.8      |   3721.8
      m = n = k = 3968  |     6268.2      |   3941.8

Times are in microseconds (us).
```

</details>
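For readers unfamiliar with the `input_precision="tf32x3"` mode exercised by the script above, here is a minimal NumPy sketch of the general 3xTF32 idea (an illustration of the technique only, not Triton's implementation): each FP32 operand is split into a TF32-representable "big" part plus a "small" remainder, and the product is assembled from three reduced-precision multiplies accumulated in FP32.

```python
# Illustration of the 3xTF32 decomposition (not Triton's implementation).
import numpy as np


def to_tf32(x: np.ndarray) -> np.ndarray:
    """Simulate TF32 by truncating an FP32 mantissa to its top 10 bits."""
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)


def matmul_3xtf32(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    a_big, b_big = to_tf32(a), to_tf32(b)
    # Remainders (kept in fp32 here; the hardware path rounds them to TF32 too).
    a_small, b_small = a - a_big, b - b_big
    # Three reduced-precision products accumulated in FP32; a_small @ b_small is dropped.
    return a_big @ b_big + a_big @ b_small + a_small @ b_big


a = np.random.randn(256, 256).astype(np.float32)
b = np.random.randn(256, 256).astype(np.float32)
ref = a.astype(np.float64) @ b.astype(np.float64)
print("3xTF32 max abs error:", np.abs(matmul_3xtf32(a, b) - ref).max())
print("1xTF32 max abs error:", np.abs(to_tf32(a) @ to_tf32(b) - ref).max())
```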
Otherwise, we may include `libGPUInstrumentationTestLib.so` in the commit.
chengjunlu approved these changes on Oct 22, 2024.
This reverts commit f9688ab.
This PR changes the Triton base from fa229d1 to d207894 (Oct 16).
Pass rate: 98.98% -> 98.99%
Please do not squash and merge this PR.