16 changes: 8 additions & 8 deletions dev/modal/benchmarks.py
@@ -16,8 +16,8 @@

@app.function(gpu="H100", image=repo, timeout=60 * 45)
def liger_benchmarks():
import subprocess
import os
import subprocess

subprocess.run(
["uv pip install -e '.[dev]' --system"],
@@ -31,7 +31,7 @@ def liger_benchmarks():
file_path = Path(REMOTE_ROOT_PATH) / "benchmark" / "data" / "all_benchmark_data.csv"
print(f"Checking if file exists at: {file_path}")
print(f"File exists: {os.path.exists(file_path)}")

if not os.path.exists(file_path):
print("Listing directory contents:")
data_dir = file_path.parent
@@ -54,21 +54,21 @@ def main():
# Run the benchmarks and get the data
print("Starting benchmark run...")
benchmark_data = liger_benchmarks.remote()

if not benchmark_data:
raise ValueError("No data received from remote function")

# Save the data locally
local_data_path = ROOT_PATH / "benchmark" / "data" / "all_benchmark_data.csv"
print(f"Attempting to save data to: {local_data_path}")

local_data_path.parent.mkdir(parents=True, exist_ok=True)

with open(local_data_path, "wb") as f:
f.write(benchmark_data)

print(f"Successfully saved {len(benchmark_data)} bytes to: {local_data_path}")

except Exception as e:
print(f"Error occurred: {str(e)}")
raise
38 changes: 20 additions & 18 deletions test/convergence/bf16/test_mini_models.py
@@ -38,6 +38,8 @@
from test.utils import DEFAULT_DATASET_PATH
from test.utils import MiniModelConfig
from test.utils import assert_verbose_allclose
from test.utils import get_logprobs
from test.utils import get_topk
from test.utils import revert_liger_kernel_to_gemma
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_gemma3_text
@@ -851,17 +853,17 @@ def run_mini_model(
eval_output = model(**eval_batch)
print(f"Eval Loss: {eval_output.loss.item()}")
loss_list.append(eval_output.loss.item())

topk_logprobs = get_topk(get_logprobs(eval_output.logits))
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
return {
"loss": loss_list,
"logits": eval_output.logits,
"topk_logprobs": topk_logprobs.values,
"model": model,
}


@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
[
pytest.param(
"mini_llama3",
@@ -884,7 +886,7 @@ def run_mini_model(
1e-3,
1e-2,
1e-1,
1e-2,
1e-1,
1e-2,
1e-2,
marks=[
@@ -902,7 +904,7 @@ def run_mini_model(
torch.bfloat16,
1e-3,
1e-2,
1, # 1e-1
1e-1, # 1e-1
1e-1, # 1e-2
Comment on lines +907 to 908

Collaborator:
After removing the all-logprobs comparison, we can try setting it lower.
sglang only has atol and sets it to 5e-2 (decode_tolerance).
verl sets (atol, rtol) = (1e-2, 1e-5), but it uses the mean of all logprobs, not topk.

Contributor Author (Manan17, Jun 5, 2025):
It does not work with a lower tolerance.
For gemma3, it passes when atol=1e-1 and rtol=1.

Contributor Author:
I tested this out with fp32; it fails for most of the models where the old logic for checking the logits passes.

Collaborator:
Since we are comparing values in log-space, the total tolerance here is effectively a relative tolerance.

Contributor Author:
Can we just check the rtol?
Like: tolerance = rtol * torch.abs(tensor2)

Collaborator (Tcc0403, Jun 9, 2025):
The absolute diff of two logprobs (logA - logB) equals the relative diff of the two probs (A / B), which means the whole tolerance (atol + rtol * torch.abs(expected)) is the maximum relative diff we can accept.

I think that's also why sglang only has a single tolerance in their test.

1e-2,
1e-2,
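To make the point from the review thread above concrete, here is a minimal sketch (not part of this PR; it assumes only torch) showing that an absolute tolerance on log-probabilities bounds the ratio of the underlying probabilities, i.e. it acts like a relative tolerance on the probabilities themselves:

```python
import torch

tol = 1e-1  # absolute tolerance applied in log-space

# Two log-probabilities that differ by exactly `tol`.
log_a = torch.tensor(-2.0)
log_b = log_a + tol

# |log_a - log_b| <= tol  is equivalent to  exp(-tol) <= A / B <= exp(tol),
# so an absolute tolerance in log-space caps the *ratio* of the probabilities.
ratio = torch.exp(log_a - log_b)
print(abs(log_a - log_b).item())  # ~0.1
print(ratio.item())               # ~0.905, i.e. exp(-0.1)

# Allclose-style checks accept |a - b| <= atol + rtol * |b|, so in log-space
# the combined quantity atol + rtol * |expected logprob| is the maximum
# relative error of the underlying probabilities that will still pass.
```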
@@ -972,7 +974,7 @@ def run_mini_model(
torch.bfloat16,
1e-3,
1e-2,
1, # 1e-1
1e-1, # 1e-1
1e-1, # 1e-2
1e-2,
1e-2,
@@ -1111,8 +1113,8 @@ def run_mini_model(
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-1,
1e-2,
1e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1124,8 +1126,8 @@ def run_mini_model(
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-1,
1e-2,
1e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1153,8 +1155,8 @@ def run_mini_model(
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
3e-1,
4e-1,
1e-2,
1e-2,
marks=[
@@ -1174,8 +1176,8 @@ def test_mini_model(
dtype,
loss_atol,
loss_rtol,
logits_atol,
logits_rtol,
logprobs_atol,
logprobs_rtol,
param_atol,
param_rtol,
):
@@ -1193,13 +1195,13 @@
rtol=loss_rtol,
)

# Compare the logits from evaluation step
if expected_output["logits"] is not None and actual_output["logits"] is not None:
# Compare the topk logprobs from evaluation step
if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None:
assert_verbose_allclose(
expected_output["logits"],
actual_output["logits"],
atol=logits_atol,
rtol=logits_rtol,
expected_output["topk_logprobs"],
actual_output["topk_logprobs"],
atol=logprobs_atol,
rtol=logprobs_rtol,
)

# Compare the params from the last step
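For reference, a rough sketch of what the `get_logprobs` and `get_topk` helpers imported from `test.utils` above might look like. Their real implementations are not shown in this diff, so the exact signatures and the default `k` below are assumptions:

```python
import torch
import torch.nn.functional as F


def get_logprobs(logits: torch.Tensor) -> torch.Tensor:
    # Convert raw logits to log-probabilities over the vocabulary (last) dim.
    return F.log_softmax(logits.float(), dim=-1)


def get_topk(logprobs: torch.Tensor, k: int = 20):
    # torch.topk returns a named tuple with .values and .indices, which is why
    # the tests read `topk_logprobs.values` before comparing.
    return torch.topk(logprobs, k=k, dim=-1)
```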
26 changes: 16 additions & 10 deletions test/convergence/bf16/test_mini_models_multimodal.py
@@ -20,6 +20,8 @@
from test.utils import UNTOKENIZED_DATASET_PATH
from test.utils import MiniModelConfig
from test.utils import assert_verbose_allclose
from test.utils import get_logprobs
from test.utils import get_topk
from test.utils import is_torchvision_available
from test.utils import load_image_processing_config
from test.utils import load_processor_config
@@ -764,13 +766,17 @@ def run_mini_model_multimodal(

print(f"Step {i}, Loss: {output.loss.item()}")
loss_list.append(output.loss.item())

topk_logprobs = get_topk(get_logprobs(output.logits))
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
return {"loss": loss_list, "logits": output.logits, "model": model}
return {
"loss": loss_list,
"topk_logprobs": topk_logprobs.values,
"model": model,
}


@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
[
pytest.param(
"mini_qwen2_vl",
@@ -917,8 +923,8 @@ def test_mini_model_multimodal(
dtype,
loss_atol,
loss_rtol,
logits_atol,
logits_rtol,
logprobs_atol,
logprobs_rtol,
param_atol,
param_rtol,
):
@@ -937,12 +943,12 @@
rtol=loss_rtol,
)

# Compare the logits from the last step
# Compare the topk logprobs from evaluation step
assert_verbose_allclose(
expected_output["logits"],
actual_output["logits"],
atol=logits_atol,
rtol=logits_rtol,
expected_output["topk_logprobs"],
actual_output["topk_logprobs"],
atol=logprobs_atol,
rtol=logprobs_rtol,
)

# Compare the params from the last step
29 changes: 18 additions & 11 deletions test/convergence/bf16/test_mini_models_with_logits.py
@@ -38,6 +38,8 @@
from test.utils import DEFAULT_DATASET_PATH
from test.utils import MiniModelConfig
from test.utils import assert_verbose_allclose
from test.utils import get_logprobs
from test.utils import get_topk
from test.utils import revert_liger_kernel_to_gemma
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_gemma3_text
@@ -842,12 +844,17 @@ def run_mini_model(
print(f"Step {i}, Loss: {output.loss.item()}")
loss_list.append(output.loss.item())

topk_logprobs = get_topk(get_logprobs(output.logits))
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
return {"loss": loss_list, "logits": output.logits, "model": model}
return {
"loss": loss_list,
"topk_logprobs": topk_logprobs.values,
"model": model,
}


@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
[
pytest.param(
"mini_llama3",
@@ -1058,8 +1065,8 @@ def run_mini_model(
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-1,
1e-2,
1e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1071,8 +1078,8 @@ def run_mini_model(
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-1,
1e-2,
1e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1159,8 +1166,8 @@ def test_mini_model(
dtype,
loss_atol,
loss_rtol,
logits_atol,
logits_rtol,
logprobs_atol,
logprobs_rtol,
param_atol,
param_rtol,
):
@@ -1180,12 +1187,12 @@

# No logits are materialized
# import pdb; pdb.set_trace()
# Compare the logits from the last step
# Compare the topk logprobs from evaluation step
assert_verbose_allclose(
expected_output["logits"],
actual_output["logits"],
atol=logits_atol,
rtol=logits_rtol,
expected_output["topk_logprobs"],
actual_output["topk_logprobs"],
atol=logprobs_atol,
rtol=logprobs_rtol,
)

# Compare the params from the last step
26 changes: 14 additions & 12 deletions test/convergence/fp32/test_mini_models.py
@@ -38,6 +38,8 @@
from test.utils import DEFAULT_DATASET_PATH
from test.utils import MiniModelConfig
from test.utils import assert_verbose_allclose
from test.utils import get_logprobs
from test.utils import get_topk
from test.utils import revert_liger_kernel_to_gemma
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_gemma3_text
@@ -849,17 +851,17 @@ def run_mini_model(
eval_output = model(**eval_batch)
print(f"Eval Loss: {eval_output.loss.item()}")
loss_list.append(eval_output.loss.item())

topk_logprobs = get_topk(get_logprobs(eval_output.logits))
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
return {
"loss": loss_list,
"logits": eval_output.logits,
"topk_logprobs": topk_logprobs.values,
"model": model,
}


@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
[
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
pytest.param(
@@ -1013,7 +1015,7 @@ def run_mini_model(
# TODO: mixtral is flaky so disable the test for now
# ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-2, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
pytest.param(
@@ -1041,8 +1043,8 @@ def test_mini_model(
dtype,
loss_atol,
loss_rtol,
logits_atol,
logits_rtol,
logprobs_atol,
logprobs_rtol,
param_atol,
param_rtol,
):
@@ -1060,13 +1062,13 @@
rtol=loss_rtol,
)

# Compare the logits from evaluation step
if expected_output["logits"] is not None and actual_output["logits"] is not None:
# Compare the topk logprobs from evaluation step
if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None:
assert_verbose_allclose(
expected_output["logits"],
actual_output["logits"],
atol=logits_atol,
rtol=logits_rtol,
expected_output["topk_logprobs"],
actual_output["topk_logprobs"],
atol=logprobs_atol,
rtol=logprobs_rtol,
)

# Compare the params from the last step