Skip to content

Commit 9ee7014

Browse files
Merge branch 'main' into kolehma8/dist_swiglu
2 parents d7fbcc9 + 5841280 commit 9ee7014

File tree

14 files changed

+3467
-42
lines changed

14 files changed

+3467
-42
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ loss.backward()
293293
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
294294
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
295295
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
296+
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |
296297

297298

298299
### Alignment Kernels

benchmark/data/all_benchmark_data.csv

Lines changed: 48 additions & 36 deletions
Large diffs are not rendered by default.

benchmark/scripts/benchmark_mhc.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import os
2+
import sys
3+
4+
import torch
5+
import triton
6+
7+
from utils import QUANTILES
8+
from utils import SingleBenchmarkRunInput
9+
from utils import SingleBenchmarkRunOutput
10+
from utils import _test_memory
11+
from utils import parse_benchmark_script_args
12+
from utils import run_benchmarks
13+
14+
from liger_kernel.transformers.functional import liger_mhc_coeffs
15+
from liger_kernel.transformers.functional import liger_mhc_post_res
16+
from liger_kernel.transformers.functional import liger_mhc_pre
17+
from liger_kernel.utils import infer_device
18+
19+
device = infer_device()
20+
21+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
22+
23+
24+
def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark latency of one mHC sub-kernel (liger vs. torch reference).

    Args:
        input: Benchmark run description. ``input.x`` is the sequence length T;
            ``input.extra_benchmark_config`` carries the tensor shape (B, HC, C),
            the sub-kernel name ("coeffs" | "pre" | "post_res") and the mHC
            hyper-parameters; ``input.kernel_provider`` selects "liger" or the
            PyTorch reference; ``input.kernel_operation_mode`` is one of
            "forward" | "backward" | "full".

    Returns:
        SingleBenchmarkRunOutput with the 20/50/80th-percentile latencies in ms.

    Raises:
        ValueError: If ``sub_kernel`` or the operation mode is unrecognized
            (previously this fell through to an opaque NameError).
    """
    # Imported lazily so the reference implementation is only needed when run.
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
    need_grad = mode in ("backward", "full")

    # Shared inputs: activations x of shape (B, T, HC, C) plus the mHC
    # projection/bias/scale parameters. M = HC*HC + 2*HC coefficient columns.
    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(need_grad)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)

    # Tensors whose .grad do_bench must clear between backward iterations.
    grad_to_none = [x, phi, b_param, alpha_pre, alpha_post, alpha_res] if need_grad else None

    if sub_kernel == "coeffs":

        def fwd():
            if provider == "liger":
                return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)

        def fwd_loss():
            # Scalar loss touching all three outputs so backward reaches them.
            h_pre, h_post, h_res = fwd()
            return h_pre.square().mean() + h_post.square().mean() + h_res.square().mean()

    elif sub_kernel == "pre":
        # Precompute the pre-coefficients once (outside the timed region).
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(need_grad)
        grad_to_none = [x, h_pre_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_pre(x, h_pre_c)
            return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)

        def fwd_loss():
            return fwd().square().mean()

    elif sub_kernel == "post_res":
        # Precompute the post/residual coefficients once (outside the timed region).
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(need_grad)
        h_res_c.requires_grad_(need_grad)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
        grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
                -1
            ) * f_out.float().unsqueeze(-2)

        def fwd_loss():
            return fwd().square().mean()

    else:
        # Previously an unknown value fell through and `fwd` was never defined,
        # raising a confusing NameError below.
        raise ValueError(f"Unknown sub_kernel: {sub_kernel!r}")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        # Build the graph once; time only the backward pass.
        y = fwd_loss()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=grad_to_none,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd_loss()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES)
    else:
        # Previously an unknown mode left ms_* unbound (UnboundLocalError).
        raise ValueError(f"Unknown kernel_operation_mode: {mode!r}")

    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
131+
132+
133+
def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark peak memory of one mHC sub-kernel (forward + backward).

    Args:
        input: Benchmark run description. ``input.x`` is the sequence length T;
            ``input.extra_benchmark_config`` carries the tensor shape (B, HC, C),
            the sub-kernel name ("coeffs" | "pre" | "post_res") and the mHC
            hyper-parameters; ``input.kernel_provider`` selects "liger" or the
            PyTorch reference.

    Returns:
        SingleBenchmarkRunOutput with the 20/50/80th-percentile peak memory (MB).

    Raises:
        ValueError: If ``sub_kernel`` is unrecognized (previously this fell
            through to an opaque NameError on `full`).
    """
    # Imported lazily so the reference implementation is only needed when run.
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)

    # Shared inputs: activations x of shape (B, T, HC, C) plus the mHC
    # projection/bias/scale parameters; memory runs always take gradients.
    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True)
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)

    if sub_kernel == "coeffs":

        def full():
            if provider == "liger":
                hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            else:
                hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            (hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward()

    elif sub_kernel == "pre":
        # Precompute the pre-coefficients once (outside the measured region).
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(True)

        def full():
            if provider == "liger":
                out = liger_mhc_pre(x, h_pre_c)
            else:
                out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)
            out.square().mean().backward()

    elif sub_kernel == "post_res":
        # Precompute the post/residual coefficients once (outside the measured region).
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(True)
        h_res_c.requires_grad_(True)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True)

        def full():
            if provider == "liger":
                out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            else:
                out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
                    -1
                ) * f_out.float().unsqueeze(-2)
            out.square().mean().backward()

    else:
        # Previously an unknown value fell through and `full` was never
        # defined, raising a confusing NameError below.
        raise ValueError(f"Unknown sub_kernel: {sub_kernel!r}")

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
213+
214+
215+
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # One speed run (forward/backward/full) and one memory run (full) per
    # mHC sub-kernel, sharing identical shapes and hyper-parameters.
    for sub_kernel in ("coeffs", "pre", "post_res"):
        extra_cfg = {
            "B": 4,
            "HC": 4,
            "C": 4096,
            "tmax": 20,
            "rms_eps": 1e-6,
            "pre_eps": 0.0,
            "sinkhorn_eps": 1e-6,
            "post_mult": 2.0,
            "sub_kernel": sub_kernel,
        }
        common_configs = dict(
            kernel_name=f"mhc_{sub_kernel}",
            x_name="T",
            x_label="Sequence Length (T)",
            x_values=[128, 256, 512, 1024, 2048],  # 2**7 .. 2**11
            kernel_providers=["liger", "torch"],
            extra_benchmark_configs=[extra_cfg],
            overwrite=args.overwrite,
        )

        run_benchmarks(
            bench_test_fn=bench_speed_mhc,
            kernel_operation_modes=["forward", "backward", "full"],
            metric_name="speed",
            metric_unit="ms",
            **common_configs,
        )

        run_benchmarks(
            bench_test_fn=bench_memory_mhc,
            kernel_operation_modes=["full"],
            metric_name="memory",
            metric_unit="MB",
            **common_configs,
        )

0 commit comments

Comments
 (0)