
Commit 1ea9175

Authored by Tcc0403
fix(benchmark): move chunked loss module init out of measurements (#643)
## Summary

Fix #789

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
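For context, the sketch below illustrates the pattern this commit applies in all four benchmark scripts: construct the loss module once, outside the measured closure, so the timed region contains only the forward pass. `LMHeadLoss` is a hypothetical stand-in for the scripts' `TorchLMHead*`/`LigerLMHead*` modules, not a class from this repo.

```python
import torch
import torch.nn as nn


class LMHeadLoss(nn.Module):
    """Hypothetical stand-in for the benchmarks' TorchLMHead*/LigerLMHead* modules."""

    def __init__(self, H, V, dtype):
        super().__init__()
        self.lin = nn.Linear(H, V, bias=False, dtype=dtype)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, x, target):
        logits = self.lin(x)
        # First tuple element is the loss, mirroring the benchmarked modules.
        return (self.ce(logits.view(-1, logits.size(-1)), target.view(-1)),)


H, V, dtype, device = 64, 128, torch.float32, "cpu"

# Before the fix: the module is rebuilt inside the lambda, so every call pays
# for weight initialization and the .to(device) move, and those costs land
# inside the measured region, skewing both speed and memory numbers.
fwd_before = lambda x, t: LMHeadLoss(H, V, dtype).to(device)(x, t)[0]

# After the fix: build once, then the closure runs only the forward pass.
lm_head = LMHeadLoss(H, V, dtype).to(device)
fwd_after = lambda x, t: lm_head(x, t)[0]

x = torch.randn(2, 8, H, dtype=dtype)
t = torch.randint(V, (2, 8), dtype=torch.long)
assert torch.is_tensor(fwd_after(x, t))
```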
1 parent b27b4c5 commit 1ea9175

File tree

4 files changed
+56 −40 lines changed

benchmark/scripts/benchmark_cpo_loss.py

Lines changed: 14 additions & 8 deletions
@@ -36,17 +36,20 @@ def bench_memory_fused_linear_cpo_loss(
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
 
-    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
-    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0]
+    liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_cpo(_input, target)
+            return liger_fwd(_input, target)
         elif provider == "huggingface":
-            return torch_lm_head_cpo(_input, target)
+            return torch_fwd(_input, target)
 
     def full():
         y = fwd()
@@ -79,17 +82,20 @@ def bench_speed_fused_linear_cpo_loss(
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
-    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0]
+    liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_cpo(_input, target)
+            return liger_fwd(_input, target)
         elif provider == "huggingface":
-            return torch_lm_head_cpo(_input, target)
+            return torch_fwd(_input, target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(

benchmark/scripts/benchmark_dpo_loss.py

Lines changed: 14 additions & 16 deletions
@@ -32,12 +32,11 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     ignore_index = input.extra_benchmark_config["ignore_index"]
     provider = input.kernel_provider
 
-    torch_dpo_loss = lambda x, ref_x, target: TorchLMHeadDPO(
-        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
-    ).to(device)(x, ref_x, target)[0]
-    liger_dpo_loss = lambda x, ref_x, target: LigerLMHeadDPO(
-        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
-    ).to(device)(x, ref_x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
+    liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
+    torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0]
+    liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0]
 
     # Input shape: [B, T, H]
     _input = torch.randn(B, T, H, device=device, dtype=dtype)
@@ -52,9 +51,9 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 
     def fwd():
         if provider == "liger":
-            return liger_dpo_loss(_input, ref_input, target)
+            return liger_fwd(_input, ref_input, target)
         elif provider == "huggingface":
-            return torch_dpo_loss(_input, ref_input, target)
+            return torch_fwd(_input, ref_input, target)
 
     def full():
         y = fwd()
@@ -83,12 +82,11 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_dpo_loss = lambda x, ref_x, target: TorchLMHeadDPO(
-        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
-    ).to(device)(x, ref_x, target)[0]
-    liger_dpo_loss = lambda x, ref_x, target: LigerLMHeadDPO(
-        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
-    ).to(device)(x, ref_x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
+    liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
+    torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0]
+    liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0]
 
     # Input shape: [B, T, H]
     _input = torch.randn(B, T, H, device=device, dtype=dtype)
@@ -103,9 +101,9 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 
     def fwd():
         if provider == "liger":
-            return liger_dpo_loss(_input, ref_input, target)
+            return liger_fwd(_input, ref_input, target)
         elif provider == "huggingface":
-            return torch_dpo_loss(_input, ref_input, target)
+            return torch_fwd(_input, ref_input, target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(

benchmark/scripts/benchmark_orpo_loss.py

Lines changed: 14 additions & 8 deletions
@@ -36,18 +36,21 @@ def bench_memory_fused_linear_orpo_loss(
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
 
-    torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
-    liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
+    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
     nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_orpo(_input, target, nll_target)
+            return liger_fwd(_input, target, nll_target)
         elif provider == "huggingface":
-            return torch_lm_head_orpo(_input, target, nll_target)
+            return torch_fwd(_input, target, nll_target)
 
     def full():
         y = fwd()
@@ -80,18 +83,21 @@ def bench_speed_fused_linear_orpo_loss(
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
-    liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
+    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
     nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_orpo(_input, target, nll_target)
+            return liger_fwd(_input, target, nll_target)
         elif provider == "huggingface":
-            return torch_lm_head_orpo(_input, target, nll_target)
+            return torch_fwd(_input, target, nll_target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(

benchmark/scripts/benchmark_simpo_loss.py

Lines changed: 14 additions & 8 deletions
@@ -36,17 +36,20 @@ def bench_memory_fused_linear_simpo_loss(
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
 
-    torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
-    liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0]
+    liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_simpo(_input, target)
+            return liger_fwd(_input, target)
         elif provider == "huggingface":
-            return torch_lm_head_simpo(_input, target)
+            return torch_fwd(_input, target)
 
     def full():
         y = fwd()
@@ -79,17 +82,20 @@ def bench_speed_fused_linear_simpo_loss(
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
-    liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0]
+    liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_simpo(_input, target)
+            return liger_fwd(_input, target)
         elif provider == "huggingface":
-            return torch_lm_head_simpo(_input, target)
+            return torch_fwd(_input, target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
