Skip to content

Commit 66570b1

Browse files
Update transformers version in tests-bwd and optional-deps, increase timeout, and fix tests-bwd (#794)
## Summary

Update tests-bwd and optional-deps to use transformers version 4.49.0 instead of 4.44.2. Also, increase the timeout for tests. For this change, the following modifications are also needed:

- Mllama monkey_patch fix
- Perform llava convergence tests for `transformers>=4.52.0`, as we don't materialize logits for earlier versions
- Increase test tolerance for `granite`
- Run `qwen2_vl` and `qwen2_5_vl` tests only for `transformers>=4.52.4`

## Testing Done

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
1 parent e808281 commit 66570b1

File tree

7 files changed

+49
-44
lines changed

7 files changed

+49
-44
lines changed

dev/modal/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
repo = image.add_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
1515

1616

17-
@app.function(gpu="A10G", image=repo, timeout=60 * 45)
17+
@app.function(gpu="A10G", image=repo, timeout=60 * 60)
1818
def liger_tests():
1919
import subprocess
2020

dev/modal/tests_bwd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
repo = image.add_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
1515

1616

17-
@app.function(gpu="A10G", image=repo, timeout=60 * 30)
17+
@app.function(gpu="A10G", image=repo, timeout=60 * 60)
1818
def liger_bwd_tests():
1919
import subprocess
2020

@@ -24,9 +24,9 @@ def liger_bwd_tests():
2424
shell=True,
2525
cwd=REMOTE_ROOT_PATH,
2626
)
27-
# force install transformers==4.44.2
27+
# force install transformers==4.49.0
2828
subprocess.run(
29-
["uv pip install transformers==4.44.2 --system"],
29+
["uv pip install transformers==4.49.0 --system"],
3030
check=True,
3131
shell=True,
3232
cwd=REMOTE_ROOT_PATH,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_optional_dependencies():
3131
"""Get optional dependency groups."""
3232
return {
3333
"dev": [
34-
"transformers>=4.44.2",
34+
"transformers>=4.49.0",
3535
"matplotlib>=3.7.2",
3636
"flake8>=4.0.1.1",
3737
"black>=24.4.2",

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,10 @@ def apply_liger_kernel_to_mllama(
537537
if isinstance(model, MllamaForConditionalGeneration):
538538
language_model: MllamaForCausalLM = model.language_model
539539
vision_model: MllamaVisionModel = model.vision_model
540-
text_model: MllamaTextModel = language_model
540+
if isinstance(language_model, MllamaForCausalLM):
541+
text_model: MllamaTextModel = language_model.model
542+
else:
543+
text_model = language_model
541544
elif isinstance(model, MllamaForCausalLM):
542545
text_model = model.model
543546
vision_model = None

test/convergence/bf16/test_mini_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,8 +957,8 @@ def run_mini_model(
957957
reason="LLaVa not available in this version of transformers",
958958
),
959959
pytest.mark.skipif(
960-
version.parse(transformers.__version__) < version.parse("4.49.0"),
961-
reason="Mistral not available in transformers<=4.49.0",
960+
version.parse(transformers.__version__) < version.parse("4.52.0"),
961+
reason="LLaVa doesn't materialize logits in transformers<=4.52.0 so we can't test it",
962962
),
963963
],
964964
),

test/convergence/fp32/test_mini_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,8 @@ def run_mini_model(
938938
reason="LLaVa not available in this version of transformers",
939939
),
940940
pytest.mark.skipif(
941-
version.parse(transformers.__version__) < version.parse("4.49.0"),
942-
reason="Mistral not available in transformers<=4.49.0",
941+
version.parse(transformers.__version__) < version.parse("4.52.0"),
942+
reason="LLaVa doesn't materialize logits in transformers<=4.52.0 so we can't test it",
943943
),
944944
],
945945
),
@@ -1103,7 +1103,7 @@ def run_mini_model(
11031103
torch.float32,
11041104
1e-8,
11051105
1e-4,
1106-
5e-3, # 4e-3
1106+
4e-2, # 4e-3
11071107
1e-5, # 1e-5
11081108
5e-3,
11091109
1e-5,

test/transformers/test_monkey_patch.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,6 @@ def is_llama4_available():
7474
return False
7575

7676

77-
def is_qwen2_vl_available():
78-
try:
79-
import transformers.models.qwen2_vl # noqa: F401
80-
81-
return True
82-
except ImportError:
83-
return False
84-
85-
86-
def is_qwen2_5_vl_available():
87-
try:
88-
import transformers.models.qwen2_5_vl # noqa: F401
89-
90-
return True
91-
except ImportError:
92-
return False
93-
94-
9577
def is_qwen3_available():
9678
try:
9779
import transformers.models.qwen3 # noqa: F401
@@ -365,6 +347,7 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
365347
# Ensure any monkey patching is cleaned up for subsequent tests
366348
with patch("transformers.models.mllama.modeling_mllama"):
367349
from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
350+
from transformers.models.mllama.modeling_mllama import MllamaTextModel
368351

369352
# Instantiate a dummy model
370353
config = transformers.models.mllama.configuration_mllama.MllamaConfig(
@@ -398,10 +381,14 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
398381

399382
# Check that model instance variables are not yet patched with Liger modules
400383
assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(mllama_lce_forward)
401-
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource(
402-
LigerRMSNorm.forward
403-
)
404-
for layer in dummy_model_instance.language_model.layers:
384+
385+
if isinstance(dummy_model_instance.language_model, MllamaTextModel):
386+
language_model = dummy_model_instance.language_model
387+
else:
388+
language_model = dummy_model_instance.language_model.model
389+
390+
assert inspect.getsource(language_model.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
391+
for layer in language_model.layers:
405392
assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
406393
assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
407394
assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
@@ -428,10 +415,8 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
428415

429416
# Check that the model's instance variables were correctly patched with Liger modules
430417
assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(mllama_lce_forward)
431-
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource(
432-
LigerRMSNorm.forward
433-
)
434-
for layer in dummy_model_instance.language_model.layers:
418+
assert inspect.getsource(language_model.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
419+
for layer in language_model.layers:
435420
assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
436421
assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
437422
assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
@@ -452,7 +437,6 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
452437
assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(
453438
LigerLayerNorm.forward
454439
)
455-
456440
try:
457441
print(dummy_model_instance)
458442
except Exception as e:
@@ -1130,7 +1114,10 @@ def test_apply_liger_kernel_to_instance_for_qwen3_moe():
11301114
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
11311115

11321116

1133-
@pytest.mark.skipif(not is_qwen2_vl_available(), reason="qwen2_vl module not available")
1117+
@pytest.mark.skipif(
1118+
transformer_version < version.parse("4.52.4"),
1119+
reason="Qwen2-VL support is only compatible with transformers >= 4.52.4",
1120+
)
11341121
def test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation():
11351122
# Ensure any monkey patching is cleaned up for subsequent tests
11361123
with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"):
@@ -1196,7 +1183,10 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation(
11961183
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
11971184

11981185

1199-
@pytest.mark.skipif(not is_qwen2_vl_available(), reason="qwen2_vl module not available")
1186+
@pytest.mark.skipif(
1187+
transformer_version < version.parse("4.52.4"),
1188+
reason="Qwen2-VL support is only compatible with transformers >= 4.52.4",
1189+
)
12001190
def test_apply_liger_kernel_to_instance_for_qwen2_vl():
12011191
# Ensure any monkey patching is cleaned up for subsequent tests
12021192
with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"):
@@ -1262,7 +1252,10 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl():
12621252
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
12631253

12641254

1265-
@pytest.mark.skipif(not is_qwen2_vl_available(), reason="qwen2_vl module not available")
1255+
@pytest.mark.skipif(
1256+
transformer_version < version.parse("4.52.4"),
1257+
reason="Qwen2-VL support is only compatible with transformers >= 4.52.4",
1258+
)
12661259
def test_apply_liger_kernel_to_instance_for_qwen2_vl_text():
12671260
# Ensure any monkey patching is cleaned up for subsequent tests
12681261
with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"):
@@ -1310,7 +1303,10 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_text():
13101303
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
13111304

13121305

1313-
@pytest.mark.skipif(not is_qwen2_5_vl_available(), reason="qwen2_5_vl module not available")
1306+
@pytest.mark.skipif(
1307+
transformer_version < version.parse("4.52.4"),
1308+
reason="Qwen2.5-VL support is only compatible with transformers >= 4.52.4",
1309+
)
13141310
def test_apply_liger_kernel_to_instance_for_qwen2_5_vl():
13151311
# Ensure any monkey patching is cleaned up for subsequent tests
13161312
with patch("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl"):
@@ -1376,7 +1372,10 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl():
13761372
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
13771373

13781374

1379-
@pytest.mark.skipif(not is_qwen2_5_vl_available(), reason="qwen2_5_vl module not available")
1375+
@pytest.mark.skipif(
1376+
transformer_version < version.parse("4.52.4"),
1377+
reason="Qwen2.5-VL support is only compatible with transformers >= 4.52.4",
1378+
)
13801379
def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_for_conditional_generation():
13811380
# Ensure any monkey patching is cleaned up for subsequent tests
13821381
with patch("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl"):
@@ -1442,7 +1441,10 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_for_conditional_generatio
14421441
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
14431442

14441443

1445-
@pytest.mark.skipif(not is_qwen2_5_vl_available(), reason="qwen2_5_vl module not available")
1444+
@pytest.mark.skipif(
1445+
transformer_version < version.parse("4.52.4"),
1446+
reason="Qwen2.5-VL support is only compatible with transformers >= 4.52.4",
1447+
)
14461448
def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_text():
14471449
# Ensure any monkey patching is cleaned up for subsequent tests
14481450
with patch("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl"):

0 commit comments

Comments
 (0)