Skip to content

Commit 7656bcb

Browse files
authored
Fix memory requirements for benchmark_jsd.py and benchmark_distill_jsd_loss.py (#1050)
## Summary In #1037 I mixed `benchmark_jsd.py` and `benchmark_distill_jsd_loss.py`. It is `benchmark_distill_jsd_loss.py` that requires 39GBs of memory: https://github.com/linkedin/Liger-Kernel/blob/0ea0b8ffcee27c5c94ffa87e480ea95036a0d2da/benchmark/data/all_benchmark_data.csv#L746 69GBs of memory: https://github.com/linkedin/Liger-Kernel/blob/0ea0b8ffcee27c5c94ffa87e480ea95036a0d2da/benchmark/data/all_benchmark_data.csv#L747 `benchmark_jsd.py` requires just 53GBs: https://github.com/linkedin/Liger-Kernel/blob/0ea0b8ffcee27c5c94ffa87e480ea95036a0d2da/benchmark/data/all_benchmark_data.csv#L459 ## Testing Done I run both benchmarks on XPU Intel GPU Max 1100 with 48 GBs of memory - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
1 parent 0ea0b8f commit 7656bcb

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

benchmark/scripts/benchmark_distill_jsd_loss.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from utils import run_benchmarks
1313

1414
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
15+
from liger_kernel.utils import get_total_gpu_memory
1516
from liger_kernel.utils import infer_device
1617

1718
device = infer_device()
@@ -224,12 +225,20 @@ def full():
224225

225226
if __name__ == "__main__":
226227
args = parse_benchmark_script_args()
228+
gpu_memory_gbs = get_total_gpu_memory()
229+
# We know that the full test will require 69GBs for vocab size 2^13 and 39GBs for vocab size 2^12 on torch
230+
if gpu_memory_gbs >= 69:
231+
x_max = 13
232+
elif gpu_memory_gbs >= 39:
233+
x_max = 12
234+
else:
235+
x_max = 11
227236

228237
common_configs = {
229238
"kernel_name": "distill_jsd_loss",
230239
"x_name": "BT",
231240
"x_label": "B x T",
232-
"x_values": [2**i for i in range(10, 14)],
241+
"x_values": [2**i for i in range(10, x_max + 1)],
233242
"kernel_providers": ["liger", "torch"],
234243
"extra_benchmark_configs": [
235244
{

benchmark/scripts/benchmark_jsd.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,11 @@ def full():
125125
if __name__ == "__main__":
126126
args = parse_benchmark_script_args()
127127
gpu_memory_gbs = get_total_gpu_memory()
128-
# We know that the full test will require 69GBs for vocab size 2^17 and 39GBs for vocab size 2^16 on torch
129-
if gpu_memory_gbs >= 69:
128+
# We know that the full test will require 54GBs for vocab size 2^17 on torch
129+
if gpu_memory_gbs >= 54:
130130
x_max = 17
131-
elif gpu_memory_gbs >= 39:
132-
x_max = 16
133131
else:
134-
x_max = 15
132+
x_max = 16
135133
common_args = {
136134
"kernel_name": "jsd",
137135
"x_name": "V",

0 commit comments

Comments
 (0)