
Commit 74589e6

Fix Backward Compatibility for Convergence Test (#1078)

## Summary

This PR restores backward compatibility of the convergence tests with transformers v4 (v4.49.0 ~ v4.57.6). During the initial development phase for transformers v5 support, backward compatibility was intentionally deprioritized, which led to significant test regressions. This PR fixes those regressions while keeping a stable foundation for the ongoing v5 integration.

## Related Issues & PRs

- #978
- #994

## Details

The current codebase assumes transformers v5 conventions, which broke compatibility with the v4.x series in two major areas:

1. RoPE parameters: some model configs no longer provide the RoPE settings (`rope_scaling`) that v4 expects, because those fields were unified into `rope_parameters` in transformers v5.
2. Tokenizer consistency: v4 and v5 expose different tokenizer interfaces. The v5 tokenizer classes select the appropriate backend themselves, while the v4 slow tokenizers are pure-Python implementations backed by SentencePiece.

Key fixes (sketched in code after this description):

- Added conditional logic that supplies the RoPE parameters in whichever form the installed transformers version expects.
- Enforced the fast (`*TokenizerFast`) tokenizer classes for transformers < v5 to resolve interface mismatches.

## Testing Done

I ran `python -m pytest test/convergence/*` on several transformers versions, both on the original branch and with this change applied:

| Branch | v4.49.0 | v4.57.6 | v5.0.0 |
|---|---|---|---|
| transformer-5.0.0rc1 | 8 failed, 37 passed, 98 skipped, 1 warning | 42 failed, 92 passed, 9 skipped, 3 warnings | 19 failed, 115 passed, 9 skipped, 29 warnings |
| This PR | 0 failed, 45 passed, 98 skipped, 1 warning | 0 failed, 134 passed, 9 skipped, 19 warnings | 19 failed, 115 passed, 9 skipped, 29 warnings |

All tests that still fail on v5.0.0 were inspected carefully; each one fails with the same error it threw before this change.

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
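
To make the RoPE portion of the fix concrete, here is a minimal, self-contained sketch of the gating pattern. The constant name and the Llama-3-style values are taken from the diffs below; `llama3_rope_kwargs` is a hypothetical helper used only for illustration, not code from this PR.

```python
import transformers
from packaging import version

# Same version gate the tests define (name taken from the diffs below).
IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0")


def llama3_rope_kwargs() -> dict:
    """Hypothetical helper: RoPE kwargs in the form the installed release expects."""
    if IS_TRANSFORMERS_V5_OR_LATER:
        # v5 unifies RoPE settings under `rope_parameters` and fills in its own
        # defaults, so the tests pass no extra `rope_scaling` kwarg there.
        return {}
    # v4.x still reads the `rope_scaling` dict; values mirror the mini-Llama3 config below.
    return {
        "rope_scaling": {
            "rope_type": "llama3",
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
        }
    }


# Usage sketch: LlamaConfig(..., **llama3_rope_kwargs())
```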

6 files changed (+270, -47 lines)

test/convergence/bf16/test_mini_models.py

Lines changed: 47 additions & 7 deletions
@@ -4,8 +4,10 @@
 
 import pytest
 import torch
+import transformers
 
 from datasets import load_from_disk
+from packaging import version
 from torch.utils.data import DataLoader
 from transformers.models.gemma import GemmaConfig
 from transformers.models.gemma import GemmaForCausalLM
@@ -53,6 +55,7 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe
 from liger_kernel.transformers import apply_liger_kernel_to_smollm3
+from liger_kernel.utils import infer_device
 from test.utils import DEFAULT_DATASET_PATH
 from test.utils import MiniModelConfig
 from test.utils import assert_verbose_allclose
@@ -94,6 +97,8 @@
 from test.utils import simple_collate_fn
 from test.utils import supports_bfloat16
 
+IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0")
+
 try:
     from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
     from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
@@ -307,7 +312,6 @@
 except ImportError:
     EXAONE4_AVAILABLE = False
 
-from liger_kernel.utils import infer_device
 
 device = infer_device()
 
@@ -702,6 +706,16 @@
             use_cache=True,
             vocab_size=32000,  # 128256,
             attn_implementation="sdpa",  # default value, pytorch native attention
+            rope_scaling=dict(
+                factor=8.0,
+                high_freq_factor=4.0,
+                low_freq_factor=1.0,
+                original_max_position_embeddings=8192,
+                rope_type="llama3",
+                rope_theta=500_000,
+            )
+            if not IS_TRANSFORMERS_V5_OR_LATER
+            else None,
         ),
     )
 
@@ -728,9 +742,11 @@
             "num_hidden_layers": 4,  # 80
             "num_key_value_heads": 2,  # 8
             "rms_norm_eps": 1e-6,  # 1e-5
-            "rope_parameters": {
-                "mrope_section": [16, 24, 24],  # (temporal, height, width)
-            },
+            **(
+                {"rope_parameters": {"mrope_section": [16, 24, 24]}}  # (temporal, height, width)
+                if IS_TRANSFORMERS_V5_OR_LATER
+                else {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}}
+            ),
             "sliding_window": 4096,
             "tie_word_embeddings": False,
             "use_cache": True,
@@ -779,9 +795,11 @@
             "num_hidden_layers": 4,  # 80
             "num_key_value_heads": 2,  # 8
             "rms_norm_eps": 1e-6,  # 1e-5
-            "rope_parameters": {
-                "mrope_section": [16, 24, 24],  # (temporal, height, width)
-            },
+            **(
+                {"rope_parameters": {"mrope_section": [16, 24, 24]}}  # (temporal, height, width)
+                if IS_TRANSFORMERS_V5_OR_LATER
+                else {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}}
+            ),
             "sliding_window": 4096,
             "tie_word_embeddings": False,
             "use_cache": True,
@@ -839,6 +857,12 @@
             rms_norm_eps=1e-6,
             use_cache=True,
             vocab_size=32768,
+            rope_scaling=dict(
+                type="mrope",
+                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            )
+            if not IS_TRANSFORMERS_V5_OR_LATER
+            else None,
         ),
         vision_config=dict(
             depth=4,
@@ -893,6 +917,12 @@
             num_experts=4,
             tie_word_embeddings=False,
             mlp_only_layers=[],
+            rope_scaling=dict(
+                type="mrope",
+                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            )
+            if not IS_TRANSFORMERS_V5_OR_LATER
+            else None,
         ).to_dict(),
         vision_config=Qwen3VLMoeVisionConfig(
             depth=4,
@@ -1129,6 +1159,11 @@
             "rms_norm_eps": 1e-5,
             "vocab_size": 32000,
             "attention_bias": True,
+            **(
+                {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}}
+                if not IS_TRANSFORMERS_V5_OR_LATER
+                else {}
+            ),
         },
         vision_config={
             "depth": 4,  # 32
@@ -1199,6 +1234,11 @@
             "topk_group": 1,
             "first_k_dense_replace": 1,
             "norm_topk_prob": True,
+            **(
+                {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}}
+                if not IS_TRANSFORMERS_V5_OR_LATER
+                else {}
+            ),
         },
         vision_config={
             "depth": 4,  # 32

test/convergence/bf16/test_mini_models_multimodal.py

Lines changed: 53 additions & 10 deletions
@@ -4,11 +4,12 @@
 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # Ensure deterministic behavior with CuBLAS
 import pytest
 import torch
+import transformers
 
 from datasets import load_dataset
+from packaging import version
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerFast
-from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
 from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 
 from liger_kernel.transformers import apply_liger_kernel_to_gemma3
@@ -22,6 +23,7 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe
 from liger_kernel.transformers import apply_liger_kernel_to_smolvlm
+from liger_kernel.utils import infer_device
 from test.utils import FAKE_CONFIGS_PATH
 from test.utils import UNTOKENIZED_DATASET_PATH
 from test.utils import MiniModelConfig
@@ -49,12 +51,23 @@
 from test.utils import supports_bfloat16
 from test.utils import train_bpe_tokenizer
 
+IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0")
+
+if IS_TRANSFORMERS_V5_OR_LATER:
+    from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
+else:
+    from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast as GemmaTokenizer
+
 try:
     # Qwen2-VL is only available in transformers>=4.52.4
     import transformers
 
     from packaging import version
-    from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
+
+    if IS_TRANSFORMERS_V5_OR_LATER:
+        from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
+    else:
+        from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer
     from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
     from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
@@ -70,7 +83,11 @@
     import transformers
 
     from packaging import version
-    from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
+
+    if IS_TRANSFORMERS_V5_OR_LATER:
+        from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
+    else:
+        from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer
     from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
     from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor
@@ -82,7 +99,10 @@
     QWEN2_5_VL_AVAILABLE = False
 
 try:
-    from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
+    if IS_TRANSFORMERS_V5_OR_LATER:
+        from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
+    else:
+        from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer
     from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor
     from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig
     from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig
@@ -138,7 +158,6 @@
 
     from packaging import version
     from transformers.models.gemma.configuration_gemma import GemmaConfig
-    from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
    from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
     from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
@@ -209,7 +228,6 @@
 except ImportError:
     NUM2WORDS_AVAILABLE = False
 
-from liger_kernel.utils import infer_device
 
 device = infer_device()
 
@@ -314,6 +332,15 @@
             num_hidden_layers=4,  # 40
             num_key_value_heads=2,  # 8
             rms_norm_eps=1e-5,
+            rope_scaling=dict(
+                factor=8.0,
+                high_freq_factor=4.0,
+                low_freq_factor=1.0,
+                original_max_position_embeddings=8192,
+                rope_type="llama3",
+            )
+            if not IS_TRANSFORMERS_V5_OR_LATER
+            else None,
             tie_word_embeddings=False,
             use_cache=True,
             vocab_size=32000,  # 128256,
@@ -491,8 +518,10 @@
             num_hidden_layers=4,  # 80
             num_key_value_heads=2,  # 8
             rms_norm_eps=1e-6,  # 1e-5
-            rope_parameters=dict(
-                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            **(
+                dict(rope_parameters=dict(mrope_section=[16, 24, 24]))  # (temporal, height, width)
+                if IS_TRANSFORMERS_V5_OR_LATER
+                else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24]))
             ),
             sliding_window=4096,
             tie_word_embeddings=True,
@@ -663,8 +692,10 @@
             num_hidden_layers=4,  # 80
             num_key_value_heads=2,  # 8
             rms_norm_eps=1e-6,  # 1e-5
-            rope_parameters=dict(
-                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            **(
+                dict(rope_parameters=dict(mrope_section=[16, 24, 24]))  # (temporal, height, width)
+                if IS_TRANSFORMERS_V5_OR_LATER
+                else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24]))
             ),
             sliding_window=4096,
             tie_word_embeddings=True,
@@ -723,6 +754,12 @@
             rms_norm_eps=1e-6,
             use_cache=False,
             tie_word_embeddings=True,
+            rope_scaling=dict(
+                type="mrope",
+                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            )
+            if not IS_TRANSFORMERS_V5_OR_LATER
+            else None,
             attention_dropout=0.0,
             attention_bias=False,
         ).to_dict(),
@@ -770,6 +807,12 @@
             rms_norm_eps=1e-6,
             use_cache=False,
             tie_word_embeddings=True,
+            rope_scaling=dict(
+                type="mrope",
+                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            )
+            if not IS_TRANSFORMERS_V5_OR_LATER
+            else None,
             attention_dropout=0.0,
             attention_bias=False,
             decoder_sparse_step=1,
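A brief note on the gated tokenizer imports in this file: both branches bind a single name (`Qwen2Tokenizer`, `GemmaTokenizer`), so the rest of the test code stays version-agnostic. A minimal self-contained sketch of the pattern is below; the `from_pretrained` path in the comment is hypothetical.

```python
import transformers
from packaging import version

IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0")

if IS_TRANSFORMERS_V5_OR_LATER:
    # v5 tokenizer classes select the appropriate backend themselves.
    from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
else:
    # On v4.x, alias the fast (Rust-backed) tokenizer so the interface matches
    # what the multimodal processors in these tests expect.
    from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer

# tok = Qwen2Tokenizer.from_pretrained("./mini-qwen2-tokenizer")  # hypothetical local path
# On v4.x this yields a fast tokenizer, i.e. `tok.is_fast` is True.
```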
