Skip to content

Commit ab07bfd

Browse files
authored
Workaround for OOM error on benchmark_jsd (#1037)
## Summary The benchmark currently requires a lot of memory with the current configuration (69 GB); it is the heaviest of all benchmarks based on current results from `all_benchmark_data.csv`, so I added a workaround for GPUs without enough memory. ## Details An alternative implementation could be to replace all benchmark function calls with another function that would handle OOM errors, but that would require changing all benchmarks. We would need to replace all `triton.testing.do_bench` calls with a local function that handles OOM errors and change `_test_memory` as well. And we would probably need to start saving some `inf` results in the csv. ## Testing Done Tested specifically the changed benchmark. - Hardware Type: Intel GPU Max 1550 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
1 parent 7b51e56 commit ab07bfd

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

benchmark/scripts/benchmark_jsd.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from utils import run_benchmarks
1010

1111
from liger_kernel.transformers.jsd import LigerJSD
12+
from liger_kernel.utils import get_total_gpu_memory
1213
from liger_kernel.utils import infer_device
1314

1415
device = infer_device()
@@ -123,11 +124,19 @@ def full():
123124

124125
if __name__ == "__main__":
125126
args = parse_benchmark_script_args()
127+
gpu_memory_gbs = get_total_gpu_memory()
128+
# We know that the full test will require 69GBs for vocab size 2^17 and 39GBs for vocab size 2^16 on torch
129+
if gpu_memory_gbs >= 69:
130+
x_max = 17
131+
elif gpu_memory_gbs >= 39:
132+
x_max = 16
133+
else:
134+
x_max = 15
126135
common_args = {
127136
"kernel_name": "jsd",
128137
"x_name": "V",
129138
"x_label": "vocab size",
130-
"x_values": [2**i for i in range(12, 18)],
139+
"x_values": [2**i for i in range(12, x_max + 1)],
131140
"kernel_providers": ["liger", "torch"],
132141
"extra_benchmark_configs": [{"B": 4, "T": 2048}],
133142
"overwrite": args.overwrite,

src/liger_kernel/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,16 @@ def transformers_version_dispatch(
121121
return before_fn(*before_args, **before_kwargs)
122122
else:
123123
return after_fn(*after_args, **after_kwargs)
124+
125+
126+
def get_total_gpu_memory(device_index: int = 0) -> int:
    """Return the total memory of a GPU device in whole gibibytes (GiB).

    Args:
        device_index: Index of the device to query. Defaults to ``0``,
            preserving the previous hard-coded behavior.

    Returns:
        Total device memory floored to an integer number of GiB
        (``total_memory // 1024**3``).

    Raises:
        RuntimeError: If the detected accelerator backend is not one of
            ``cuda`` / ``xpu`` / ``npu``.
    """
    device = infer_device()
    # cuda, xpu and npu all expose the same get_device_properties API on
    # their respective torch backend module, so dispatch via getattr
    # instead of duplicating the branch bodies per backend.
    if device not in ("cuda", "xpu", "npu"):
        raise RuntimeError(f"Unsupported device: {device}")
    backend = getattr(torch, device)
    return backend.get_device_properties(device_index).total_memory // (1024**3)

0 commit comments

Comments
 (0)