# Add mHC fused kernels + LigerMHC API + benchmarks #1065
**Open**: yukiu00 wants to merge 14 commits into `linkedin:main` from `yukiu00:feat/mhc-kernel`.
Diff from all commits: +3,173 −0

## Commits (14)
- `bb37285` mhc: add kernels, tests, benchmarks, and docs
- `fe16236` refactor mhc benches/tests
- `64343c5` Format
- `72d99ef` mhc: code quality improvements
- `498b1e6` Merge branch 'main' into feat/mhc-kernel
- `aca1fb8` mhc: address PR review feedback
- `94d5da7` Refactor MHC tests to simplify imports
- `28cf670` mhc: align benchmark with standard framework and fix convergence test…
- `5aba8e0` Remove test_mhc_mini_lm.py file and integrate MiniMHCLM class into te…
- `5238230` mhc: address remaining PR #1065 review feedback
- `7027de6` Fix type hinting in mhc.py by replacing `tuple` with `Tuple` for cons…
- `cb25136` Add LigerMHC to the transformers module exports.
- `af0e661` Remove TorchMHCCoeffs class from test_mhc.py
- `5c9eca9` mhc: code quality cleanup across ops, tests, and benchmarks

All commits are by yukiu00.
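The functional surface exercised by the benchmark consists of three fused ops exported from `liger_kernel.transformers.functional`: `liger_mhc_coeffs`, `liger_mhc_pre`, and `liger_mhc_post_res` (commit `cb25136` also exports a module-level `LigerMHC`, whose constructor is not visible in this diff). Below is a minimal usage sketch, with shapes and hyperparameters lifted from the benchmark configuration; the randomly generated `f_out` is a stand-in for a real sub-layer output:

```python
import torch

from liger_kernel.transformers.functional import liger_mhc_coeffs, liger_mhc_post_res, liger_mhc_pre
from liger_kernel.utils import infer_device

device = infer_device()
B, T, HC, C = 4, 128, 4, 4096    # batch, sequence length, residual streams, hidden width
K, M = HC * C, HC * HC + 2 * HC  # flattened input dim -> packed coefficient dim

x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16)    # stacked residual streams
phi = torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02  # coefficient projection
b = torch.zeros(M, device=device, dtype=torch.float32)               # coefficient bias
a_pre, a_post, a_res = (torch.tensor(1.0, device=device) for _ in range(3))

# Compute all three coefficient tensors (hyperparameters copied from the benchmark config).
h_pre, h_post, h_res = liger_mhc_coeffs(
    x, phi, b, a_pre, a_post, a_res,
    tmax=20, rms_eps=1e-6, pre_eps=0.0, sinkhorn_eps=1e-6, post_mult=2.0,
)

f_in = liger_mhc_pre(x, h_pre)  # (B, T, C): mix the HC streams into one sub-layer input
f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16)  # stand-in for the sub-layer
y = liger_mhc_post_res(x, f_out, h_post, h_res)  # (B, T, HC, C): residual mix + broadcast f_out
```

The benchmark's `torch` provider implements these same three steps with plain PyTorch ops, which is what the speed and memory comparisons measure against.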
## Files changed

The file shown in this diff is the new MHC benchmark script (+255 lines):
```python
import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.functional import liger_mhc_coeffs
from liger_kernel.transformers.functional import liger_mhc_post_res
from liger_kernel.transformers.functional import liger_mhc_pre
from liger_kernel.utils import infer_device

device = infer_device()

# Make the repo root importable so the reference implementation under test/ resolves.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x  # swept value: sequence length
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
    need_grad = mode in ("backward", "full")

    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
    # K flattens the HC streams of width C; M packs the h_res (HC*HC), h_pre (HC), and h_post (HC) coefficients.
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(need_grad)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)

    grad_to_none = [x, phi, b_param, alpha_pre, alpha_post, alpha_res] if need_grad else None

    if sub_kernel == "coeffs":

        def fwd():
            if provider == "liger":
                return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)

        def fwd_loss():
            h_pre, h_post, h_res = fwd()
            return h_pre.square().mean() + h_post.square().mean() + h_res.square().mean()

    elif sub_kernel == "pre":
        # Precompute coefficients without grad so only the pre-mix op itself is measured.
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(need_grad)
        grad_to_none = [x, h_pre_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_pre(x, h_pre_c)
            return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)

        def fwd_loss():
            return fwd().square().mean()

    elif sub_kernel == "post_res":
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(need_grad)
        h_res_c.requires_grad_(need_grad)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
        grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
                -1
            ) * f_out.float().unsqueeze(-2)

        def fwd_loss():
            return fwd().square().mean()

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        y = fwd_loss()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=grad_to_none,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd_loss()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)

    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True)
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)

    if sub_kernel == "coeffs":

        def full():
            if provider == "liger":
                hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            else:
                hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            (hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward()

    elif sub_kernel == "pre":
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(True)

        def full():
            if provider == "liger":
                out = liger_mhc_pre(x, h_pre_c)
            else:
                out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)
            out.square().mean().backward()

    elif sub_kernel == "post_res":
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(True)
        h_res_c.requires_grad_(True)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True)

        def full():
            if provider == "liger":
                out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            else:
                out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
                    -1
                ) * f_out.float().unsqueeze(-2)
            out.square().mean().backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    for sub_kernel in ["coeffs", "pre", "post_res"]:
        common_configs = {
            "kernel_name": f"mhc_{sub_kernel}",
            "x_name": "T",
            "x_label": "Sequence Length (T)",
            "x_values": [2**i for i in range(7, 12)],  # T = 128 .. 2048
            "kernel_providers": ["liger", "torch"],
            "extra_benchmark_configs": [
                {
                    "B": 4,
                    "HC": 4,
                    "C": 4096,
                    "tmax": 20,
                    "rms_eps": 1e-6,
                    "pre_eps": 0.0,
                    "sinkhorn_eps": 1e-6,
                    "post_mult": 2.0,
                    "sub_kernel": sub_kernel,
                }
            ],
            "overwrite": args.overwrite,
        }

        run_benchmarks(
            bench_test_fn=bench_speed_mhc,
            kernel_operation_modes=["forward", "backward", "full"],
            metric_name="speed",
            metric_unit="ms",
            **common_configs,
        )

        run_benchmarks(
            bench_test_fn=bench_memory_mhc,
            kernel_operation_modes=["full"],
            metric_name="memory",
            metric_unit="MB",
            **common_configs,
        )
```
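For readers tracing the math, the `torch`-provider branches above double as the reference semantics of the fused ops. Here is the same dataflow collected into one sketch, with the expressions copied from those branches; `f` is a hypothetical stand-in for whatever sub-layer the mHC coefficients wrap:

```python
import torch


def mhc_step_reference(x, h_pre, h_post, h_res, f):
    """One mHC step: x is (B, T, HC, C); coefficients come from liger_mhc_coeffs."""
    f_in = (x.float() * h_pre.unsqueeze(-1)).sum(dim=-2)  # "pre": mix streams, (B, T, C)
    f_out = f(f_in)                                       # wrapped sub-layer, (B, T, C)
    # "post_res": mix streams with the (HC, HC) residual matrix, then broadcast f_out back in.
    return torch.einsum("...oi,...ic->...oc", h_res, x.float()) + h_post.unsqueeze(-1) * f_out.float().unsqueeze(-2)
```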