
Commit 28cf670

mhc: align benchmark with standard framework and fix convergence test skipif
- benchmark_mhc.py: pass all config params via extra_benchmark_configs, following the DPO benchmark pattern
- test_mhc_mini_lm.py: remove the redundant torch.cuda.is_available() skipif (supports_bfloat16() already covers this case)
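For context, the framework hands each entry of extra_benchmark_configs to the benchmark function as input.extra_benchmark_config, alongside the provider and operation mode. A minimal sketch of that consumption side, assuming a dataclass-style input (the field names appear in the diff below; the dataclass definition itself is illustrative, not the framework's actual code):

    from dataclasses import dataclass, field

    @dataclass
    class SingleBenchmarkRunInput:
        kernel_provider: str        # e.g. "liger" or the reference implementation
        kernel_operation_mode: str  # e.g. "forward", "backward", or "full"
        extra_benchmark_config: dict = field(default_factory=dict)

    def bench_speed_mhc(input: SingleBenchmarkRunInput):
        cfg = input.extra_benchmark_config
        # Every tunable is read from the config entry, so no module-level
        # constant (like the removed COEFFS_CFG) leaks between runs.
        coeffs_cfg = dict(
            tmax=cfg["tmax"],
            rms_eps=cfg["rms_eps"],
            pre_eps=cfg["pre_eps"],
            sinkhorn_eps=cfg["sinkhorn_eps"],
            post_mult=cfg["post_mult"],
        )
        ...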

File tree (2 files changed, +25 −10 lines):

benchmark/scripts/benchmark_mhc.py
test/convergence/bf16/test_mhc_mini_lm.py

benchmark/scripts/benchmark_mhc.py

Lines changed: 25 additions & 9 deletions
@@ -20,7 +20,6 @@
 
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 
-COEFFS_CFG = dict(tmax=20, rms_eps=1e-6, pre_eps=0.0, sinkhorn_eps=1e-6, post_mult=2.0)
 B = 4
 
 
@@ -31,9 +30,15 @@ def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     HC = input.extra_benchmark_config["HC"]
     C = input.extra_benchmark_config["C"]
     sub_kernel = input.extra_benchmark_config["sub_kernel"]
+    tmax = input.extra_benchmark_config["tmax"]
+    rms_eps = input.extra_benchmark_config["rms_eps"]
+    pre_eps = input.extra_benchmark_config["pre_eps"]
+    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
+    post_mult = input.extra_benchmark_config["post_mult"]
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
+    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
     need_grad = mode in ("backward", "full")
 
     x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
@@ -50,8 +55,8 @@ def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 
     def fwd():
         if provider == "liger":
-            return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **COEFFS_CFG)
-        return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **COEFFS_CFG)
+            return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
+        return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
 
     def fwd_loss():
         h_pre, h_post, h_res = fwd()
@@ -66,7 +71,7 @@ def fwd_loss():
             alpha_pre.detach(),
             alpha_post.detach(),
             alpha_res.detach(),
-            **COEFFS_CFG,
+            **coeffs_cfg,
         )
         h_pre_c.requires_grad_(need_grad)
         grad_to_none = [x, h_pre_c] if need_grad else None
@@ -88,7 +93,7 @@ def fwd_loss():
             alpha_pre.detach(),
             alpha_post.detach(),
             alpha_res.detach(),
-            **COEFFS_CFG,
+            **coeffs_cfg,
         )
         h_post_c.requires_grad_(need_grad)
         h_res_c.requires_grad_(need_grad)
@@ -133,8 +138,15 @@ def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
     HC = input.extra_benchmark_config["HC"]
     C = input.extra_benchmark_config["C"]
     sub_kernel = input.extra_benchmark_config["sub_kernel"]
+    tmax = input.extra_benchmark_config["tmax"]
+    rms_eps = input.extra_benchmark_config["rms_eps"]
+    pre_eps = input.extra_benchmark_config["pre_eps"]
+    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
+    post_mult = input.extra_benchmark_config["post_mult"]
     provider = input.kernel_provider
 
+    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
+
     x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True)
     K, M = HC * C, HC * HC + 2 * HC
     phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True)
@@ -147,9 +159,9 @@ def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
 
         def full():
             if provider == "liger":
-                hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **COEFFS_CFG)
+                hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
             else:
-                hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **COEFFS_CFG)
+                hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
             (hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward()
 
     elif sub_kernel == "pre":
@@ -161,7 +173,7 @@ def full():
             alpha_pre.detach(),
             alpha_post.detach(),
             alpha_res.detach(),
-            **COEFFS_CFG,
+            **coeffs_cfg,
         )
         h_pre_c.requires_grad_(True)
 
@@ -181,7 +193,7 @@ def full():
             alpha_pre.detach(),
             alpha_post.detach(),
             alpha_res.detach(),
-            **COEFFS_CFG,
+            **coeffs_cfg,
         )
         h_post_c.requires_grad_(True)
         h_res_c.requires_grad_(True)
@@ -215,6 +227,10 @@ def full():
             "HC": 4,
             "C": 4096,
             "tmax": 20,
+            "rms_eps": 1e-6,
+            "pre_eps": 0.0,
+            "sinkhorn_eps": 1e-6,
+            "post_mult": 2.0,
             "sub_kernel": sub_kernel,
         }
     ],
test/convergence/bf16/test_mhc_mini_lm.py

Lines changed: 0 additions & 1 deletion
@@ -95,7 +95,6 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.head(x_merge)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
 @pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU")
 def test_mhc_mini_lm_convergence():
     set_seed(0)
