@whitneywhtsang whitneywhtsang commented Oct 22, 2024

This PR changes the Triton base from fa229d1 to d207894 (Oct 16).
Pass rate: 98.98% -> 98.99%

Please do not squash and merge this PR.

Jokeren and others added 13 commits October 15, 2024 09:16
This adds a missing exception for the warp size and fixes the dot test for m
or n > 32 when using wmma.
…4891)

Note that the current implementation uses
`DotOperandEncodingAttr::getWarpsPerCTA`, which was buggy for cases
where the warps are not of the form `[numWarps, 1]` or `[1, numWarps]`.
This PR bundles a fix for this issue.

We will activate its use for a subset of `DotOperandEncoding`s in a PR
coming soon.
Currently, PRs only get CI when the base branch is main or starts with
`dev-`. This changes it to run CI on all PRs, which works better when using
`git-pr-chain`, where the base branch usually isn't main.
- The epilogue ramp-down indexing must start at `total_iterations - max_stage`,
  clamped to zero or greater, to ensure alignment with the prologue
  ramp-up stages (see the sketch below).
- If total_iterations < max_stage, the trailing stages will be masked.

This commit mirrors upstream llvm/llvm-project#112418
and adds a functional test for correctness with num_stages=1,2,3,4.
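
For intuition, here is a minimal sketch of the ramp-down indexing described above; `epilogue_indices`, `total_iterations`, and `max_stage` are illustrative names, not the pipeliner's actual API:

```python
# Illustrative only: models the ramp-down start clamped at zero and the
# masking of trailing stages when total_iterations < max_stage.
def epilogue_indices(total_iterations: int, max_stage: int):
    """Return (iteration_index, masked) pairs for the epilogue ramp-down."""
    # The ramp-down starts at total_iterations - max_stage, clamped to zero,
    # so it stays aligned with the prologue ramp-up stages.
    start = max(total_iterations - max_stage, 0)
    stages = []
    for s in range(max_stage):
        idx = start + s
        # Trailing stages that refer to iterations that never ran are masked.
        stages.append((idx, idx >= total_iterations))
    return stages


# Example: 2 iterations with num_stages=4 -> the last two stages are masked.
assert epilogue_indices(2, 4) == [(0, False), (1, False), (2, True), (3, True)]
```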
To fix the following warnings:
```bash
Warning:  hook id `ruff` uses deprecated stage names (commit, push) which will be removed in a future version.  run: `pre-commit migrate-config` to automatically fix this.
Warning:  hook id `yapf` uses deprecated stage names (commit, push) which will be removed in a future version.  run: `pre-commit migrate-config` to automatically fix this.
Warning:  hook id `clang-format` uses deprecated stage names (commit, push) which will be removed in a future version.  run: `pre-commit migrate-config` to automatically fix this.
```

CI run, for example:
https://github.com/triton-lang/triton/actions/runs/11346522167/job/31555782632#step:6:51

I just used the `pre-commit migrate-config` command.

Signed-off-by: Anatoly Myachev <[email protected]>
These optimizations break internal workloads
This PR includes triton-lang/triton#4891 and
triton-lang/triton#4895. I will rebase once
those have landed.

It includes a number of hacks to work around bugs in
`DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be
easy to grep for. @Jokeren is working on a comprehensive revamp of
`DotOperandEncodingAttr` which will get rid of all these.
triton-lang/triton#4895 is the first step in
this direction.
…hMLIRArgs=true (#4931)

The predicate wasn't being added to the instruction representation when
`onlyAttachMLIRArgs` was set to true.
See #4603.

---
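
For context, `input_precision="tf32x3"` emulates FP32 matmul accuracy with three TF32 products. A rough numerical sketch of the idea in plain PyTorch (the truncation-based split below is illustrative only, not Triton's implementation):

```python
import torch


def split_tf32(x: torch.Tensor):
    # Keep the bits TF32 can represent (truncation for simplicity; real
    # implementations round); the residual carries the dropped low bits.
    big = (x.view(torch.int32) & -8192).view(torch.float32)  # clear low 13 mantissa bits
    return big, x - big


def mm_3xtf32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    a_big, a_small = split_tf32(a)
    b_big, b_small = split_tf32(b)
    # Three lower-precision products approximate one FP32 product;
    # the small*small term is dropped as negligible.
    return a_big @ b_big + a_big @ b_small + a_small @ b_big
```

The benchmark script below selects the same idea in the kernel with `tl.dot(..., input_precision="tf32x3")`.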

Here is a precision/performance comparison before/after this PR (on A100):

<details>
 <summary>The script used for testing</summary>

```python
import pandas as pd

import torch
import torch.utils.benchmark as benchmark

import triton
import triton.language as tl

import cutlass


dtype = torch.float32
device = "cuda"
loss = torch.nn.MSELoss()


def cutlass_mm(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    m, n = a.shape[0], b.shape[1]
    d = torch.empty((m, n), dtype=a.dtype, device=a.device)

    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
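    # multiply_add_fast_f32 requests CUTLASS's fast-FP32 (3xTF32-style) math mode.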
    plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32

    alpha = 1
    beta = 0
    plan.run(a, b, d, d, alpha, beta, print_module=False)

    return d


@triton.jit
def triton_mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
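    # Map the linear program id to a (pid_m, pid_n) tile; groups of
    # GROUP_SIZE_M row-tiles are processed together to improve L2 reuse.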
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
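    # March along K in BLOCK_SIZE_K steps, masking out-of-range elements to 0.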
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, input_precision="tf32x3")
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_mm(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 128, 32
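    # Launch one program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.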
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    triton_mm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
        num_warps=8,
        num_stages=3,
    )
    return c


torch.manual_seed(1234)

dims = []
triton_3xtf32_loss = []
cutlass_3xtf32_loss = []
for m in range(256, 4096, 128):
    n = k = m

    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)

    allow_tf32_saved = torch.backends.cuda.matmul.allow_tf32
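    # Compute the reference in full FP32 (TF32 temporarily disabled).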
    torch.backends.cuda.matmul.allow_tf32 = False
    d_ref = torch.mm(a, b)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32_saved

    d_triton_3xtf32 = triton_mm(a, b)
    d_cutlass_3xtf32 = cutlass_mm(a, b)

    dims.append(m)
    triton_3xtf32_loss.append(loss(d_triton_3xtf32, d_ref).item())
    cutlass_3xtf32_loss.append(loss(d_cutlass_3xtf32, d_ref).item())

df = pd.DataFrame(
    {
        "dims": dims,
        "Triton 3xTF32 loss": triton_3xtf32_loss,
        "CUTLASS 3xTF32 loss": cutlass_3xtf32_loss,
    }
)
print(df)

print()

results = []
label = "Triton 3xTF32 vs. CUTLASS 3xTF32 latency"
for m in range(256, 4096, 128):
    n = k = m  # mirror the first loop; otherwise n and k keep their last values
    sub_label = f"m = n = k = {m:5d}"

    a = torch.randn((m, k), dtype=dtype, device=device)
    b = torch.randn((k, n), dtype=dtype, device=device)

    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": triton_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="Triton 3xTF32",
    ).blocked_autorange()
    results.append(measurement)

    measurement = benchmark.Timer(
        stmt="mm(a, b)",
        globals={
            "mm": cutlass_mm,
            "a": a,
            "b": b,
        },
        label=label,
        sub_label=sub_label,
        description="CUTLASS",
    ).blocked_autorange()
    results.append(measurement)

compare = benchmark.Compare(results)
compare.print()
```
</details>

<details>
<summary>Test script output for vanilla Triton build</summary>

```
    dims  Triton 3xTF32 loss  CUTLASS 3xTF32 loss
0    256        1.366855e-09         5.235101e-11
1    384        4.742662e-09         8.836381e-11
2    512        1.157405e-08         1.270737e-10
3    640        2.254077e-08         1.644706e-10
4    768        3.873695e-08         2.048905e-10
5    896        6.212847e-08         2.524800e-10
6   1024        9.253924e-08         2.843547e-10
7   1152        1.318329e-07         3.507732e-10
8   1280        1.823635e-07         7.997096e-10
9   1408        2.423697e-07         4.624160e-10
10  1536        3.152084e-07         5.258877e-10
11  1664        3.999571e-07         5.849541e-10
12  1792        5.002328e-07         6.518351e-10
13  1920        6.167757e-07         1.014158e-09
14  2048        7.500014e-07         1.800559e-09
15  2176        8.983116e-07         2.005555e-09
16  2304        1.064476e-06         2.212916e-09
17  2432        1.255128e-06         2.445486e-09
18  2560        1.461378e-06         2.680297e-09
19  2688        1.688605e-06         2.921828e-09
20  2816        1.943802e-06         3.181862e-09
21  2944        2.224484e-06         3.454009e-09
22  3072        2.519756e-06         3.732411e-09
23  3200        2.850649e-06         4.019436e-09
24  3328        3.207230e-06         4.322690e-09
25  3456        3.598114e-06         4.644620e-09
26  3584        4.016068e-06         4.967569e-09
27  3712        4.458372e-06         5.296403e-09
28  3840        4.932218e-06         5.642412e-09
29  3968        5.452913e-06         6.006925e-09

[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
                         |  Triton 3xTF32  |  CUTLASS
1 threads: ------------------------------------------
      m = n = k =   256  |       525.7     |   1059.2
      m = n = k =   384  |       526.6     |   1098.4
      m = n = k =   512  |      1047.8     |   1385.1
      m = n = k =   640  |      1049.2     |   1606.8
      m = n = k =   768  |      1050.2     |   1601.3
      m = n = k =   896  |      1565.3     |   1700.0
      m = n = k =  1024  |      1571.0     |   1712.2
      m = n = k =  1152  |      1572.6     |   1912.5
      m = n = k =  1280  |      1573.4     |   1907.2
      m = n = k =  1408  |      2092.4     |   2248.6
      m = n = k =  1536  |      2094.0     |   2260.1
      m = n = k =  1664  |      2095.1     |   2242.8
      m = n = k =  1792  |      2612.2     |   2580.8
      m = n = k =  1920  |      2615.5     |   2611.6
      m = n = k =  2048  |      2617.1     |   2582.8
      m = n = k =  2176  |      2618.3     |   2696.4
      m = n = k =  2304  |      3136.8     |   2903.1
      m = n = k =  2432  |      3139.2     |   2915.2
      m = n = k =  2560  |      3144.3     |   2915.3
      m = n = k =  2688  |      3649.2     |   3270.4
      m = n = k =  2816  |      3660.1     |   3241.2
      m = n = k =  2944  |      3661.4     |   3331.5
      m = n = k =  3072  |      3664.0     |   3048.8
      m = n = k =  3200  |      4180.4     |   3379.3
      m = n = k =  3328  |      4182.6     |   3395.0
      m = n = k =  3456  |      4184.5     |   3384.3
      m = n = k =  3584  |      4690.9     |   3712.1
      m = n = k =  3712  |      4707.7     |   3921.7
      m = n = k =  3840  |      4706.4     |   3919.7
      m = n = k =  3968  |      4708.1     |   3707.0

Times are in microseconds (us).
```
</details>

<details>
<summary>Test script output for Triton build with this PR
applied</summary>

```
    dims  Triton 3xTF32 loss  CUTLASS 3xTF32 loss
0    256        9.949744e-12         5.235101e-11
1    384        2.407365e-11         8.836381e-11
2    512        3.835959e-11         1.270737e-10
3    640        5.498505e-11         1.644706e-10
4    768        7.436918e-11         2.048905e-10
5    896        9.789199e-11         2.524800e-10
6   1024        1.072674e-10         2.843547e-10
7   1152        1.520337e-10         3.507732e-10
8   1280        5.775638e-10         7.997096e-10
9   1408        2.184144e-10         4.624160e-10
10  1536        2.571353e-10         5.258877e-10
11  1664        2.963491e-10         5.849541e-10
12  1792        3.402902e-10         6.518351e-10
13  1920        6.804675e-10         1.014158e-09
14  2048        1.443346e-09         1.800559e-09
15  2176        1.625424e-09         2.005555e-09
16  2304        1.813113e-09         2.212916e-09
17  2432        2.018629e-09         2.445486e-09
18  2560        2.232485e-09         2.680297e-09
19  2688        2.452671e-09         2.921828e-09
20  2816        2.689190e-09         3.181862e-09
21  2944        2.937780e-09         3.454009e-09
22  3072        3.193837e-09         3.732411e-09
23  3200        3.460724e-09         4.019436e-09
24  3328        3.738940e-09         4.322690e-09
25  3456        4.038074e-09         4.644620e-09
26  3584        4.338085e-09         4.967569e-09
27  3712        4.644735e-09         5.296403e-09
28  3840        4.969717e-09         5.642412e-09
29  3968        5.309353e-09         6.006925e-09

[----- Triton 3xTF32 vs. CUTLASS 3xTF32 latency ----]
                         |  Triton 3xTF32  |  CUTLASS
1 threads: ------------------------------------------
      m = n = k =   256  |       701.4     |   1058.7
      m = n = k =   384  |       704.7     |   1103.7
      m = n = k =   512  |      1392.3     |   1394.9
      m = n = k =   640  |      1393.9     |   1387.5
      m = n = k =   768  |      1395.9     |   1389.7
      m = n = k =   896  |      2077.6     |   1739.9
      m = n = k =  1024  |      2088.4     |   1730.4
      m = n = k =  1152  |      2100.7     |   1737.3
      m = n = k =  1280  |      2094.9     |   1759.5
      m = n = k =  1408  |      2790.6     |   2258.8
      m = n = k =  1536  |      2786.3     |   2332.9
      m = n = k =  1664  |      2788.9     |   2251.7
      m = n = k =  1792  |      3470.9     |   2618.0
      m = n = k =  1920  |      3479.4     |   2596.3
      m = n = k =  2048  |      3480.7     |   2407.4
      m = n = k =  2176  |      3498.1     |   2541.2
      m = n = k =  2304  |      4177.2     |   2941.1
      m = n = k =  2432  |      4177.3     |   2765.4
      m = n = k =  2560  |      4180.9     |   2932.8
      m = n = k =  2688  |      4864.3     |   3100.1
      m = n = k =  2816  |      4871.8     |   3039.4
      m = n = k =  2944  |      4873.1     |   3240.2
      m = n = k =  3072  |      4875.7     |   3060.8
      m = n = k =  3200  |      5580.7     |   3638.5
      m = n = k =  3328  |      5573.4     |   3442.0
      m = n = k =  3456  |      5572.4     |   3583.3
      m = n = k =  3584  |      6259.6     |   3902.5
      m = n = k =  3712  |      6263.7     |   3909.3
      m = n = k =  3840  |      6263.8     |   3721.8
      m = n = k =  3968  |      6268.2     |   3941.8

Times are in microseconds (us).
```
</details>
Otherwise, we may include `libGPUInstrumentationTestLib.so` in the commit.
@whitneywhtsang whitneywhtsang self-assigned this Oct 22, 2024
@whitneywhtsang whitneywhtsang changed the title from "Merge OpenAI Triton commit f9688ab" to "Merge OpenAI Triton commit d997364" Oct 22, 2024
@whitneywhtsang whitneywhtsang marked this pull request as ready for review October 22, 2024 08:59
@whitneywhtsang whitneywhtsang merged commit bbba43a into main Oct 22, 2024
4 checks passed
@whitneywhtsang whitneywhtsang deleted the whitneywhtsang/merge3 branch October 22, 2024 10:06
@whitneywhtsang whitneywhtsang changed the title from "Merge OpenAI Triton commit d997364" to "Merge OpenAI Triton commit d207894" Oct 22, 2024