@@ -338,7 +338,7 @@ def test_apply_liger_kernel_to_instance_for_llama():
338338 with patch ("transformers.models.llama.modeling_llama" ):
339339 # Instantiate a dummy model
340340 config = transformers .models .llama .configuration_llama .LlamaConfig (
341- torch_dtype = torch .bfloat16 ,
341+ dtype = torch .bfloat16 ,
342342 rms_norm_eps = 1e-5 ,
343343 hidden_size = 32 ,
344344 intermediate_size = 64 ,
@@ -382,7 +382,7 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
382382
383383 # Instantiate a dummy model
384384 config = transformers .models .mllama .configuration_mllama .MllamaConfig (
385- torch_dtype = torch .bfloat16 ,
385+ dtype = torch .bfloat16 ,
386386 text_config = transformers .models .mllama .configuration_mllama .MllamaTextConfig (
387387 rms_norm_eps = 1e-5 ,
388388 hidden_size = 32 ,
@@ -533,7 +533,7 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_causal_lm():
533533
534534 # Instantiate a dummy model
535535 config = transformers .models .llama4 .configuration_llama4 .Llama4TextConfig (
536- torch_dtype = torch .bfloat16 ,
536+ dtype = torch .bfloat16 ,
537537 rms_norm_eps = 1e-5 ,
538538 hidden_size = 32 ,
539539 intermediate_size = 64 ,
@@ -573,9 +573,9 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_conditional_generation():
573573
574574 # Instantiate a dummy model
575575 config = transformers .models .llama4 .configuration_llama4 .Llama4Config (
576- torch_dtype = torch .bfloat16 ,
576+ dtype = torch .bfloat16 ,
577577 text_config = transformers .models .llama4 .configuration_llama4 .Llama4TextConfig (
578- torch_dtype = torch .bfloat16 ,
578+ dtype = torch .bfloat16 ,
579579 rms_norm_eps = 1e-5 ,
580580 hidden_size = 32 ,
581581 intermediate_size = 64 ,
@@ -656,7 +656,7 @@ def test_apply_liger_kernel_to_instance_for_mistral():
656656 with patch ("transformers.models.mistral.modeling_mistral" ):
657657 # Instantiate a dummy model
658658 config = transformers .models .mistral .configuration_mistral .MistralConfig (
659- torch_dtype = torch .bfloat16 ,
659+ dtype = torch .bfloat16 ,
660660 rms_norm_eps = 1e-5 ,
661661 hidden_size = 32 ,
662662 intermediate_size = 64 ,
@@ -695,7 +695,7 @@ def test_apply_liger_kernel_to_instance_for_mixtral():
695695 with patch ("transformers.models.mixtral.modeling_mixtral" ):
696696 # Instantiate a dummy model
697697 config = transformers .models .mixtral .configuration_mixtral .MixtralConfig (
698- torch_dtype = torch .bfloat16 ,
698+ dtype = torch .bfloat16 ,
699699 rms_norm_eps = 1e-5 ,
700700 hidden_size = 32 ,
701701 intermediate_size = 64 ,
@@ -738,7 +738,7 @@ def test_apply_liger_kernel_to_instance_for_gemma():
738738 with patch ("transformers.models.gemma.modeling_gemma" ):
739739 # Instantiate a dummy model
740740 config = transformers .models .gemma .configuration_gemma .GemmaConfig (
741- torch_dtype = torch .bfloat16 ,
741+ dtype = torch .bfloat16 ,
742742 rms_norm_eps = 1e-5 ,
743743 hidden_size = 32 ,
744744 intermediate_size = 64 ,
@@ -777,7 +777,7 @@ def test_apply_liger_kernel_to_instance_for_gemma2():
777777 with patch ("transformers.models.gemma2.modeling_gemma2" ):
778778 # Instantiate a dummy model
779779 config = transformers .models .gemma2 .configuration_gemma2 .Gemma2Config (
780- torch_dtype = torch .bfloat16 ,
780+ dtype = torch .bfloat16 ,
781781 rms_norm_eps = 1e-5 ,
782782 hidden_size = 32 ,
783783 intermediate_size = 64 ,
@@ -827,7 +827,7 @@ def test_apply_liger_kernel_to_instance_for_paligemma():
827827
828828 # Instantiate a dummy model
829829 config = transformers .models .paligemma .configuration_paligemma .PaliGemmaConfig (
830- torch_dtype = torch .bfloat16 ,
830+ dtype = torch .bfloat16 ,
831831 text_config = {
832832 "num_hidden_layers" : 2 ,
833833 "rms_norm_eps" : 1e-5 ,
@@ -883,7 +883,7 @@ def test_apply_liger_kernel_to_instance_for_gemma3_text():
883883
884884 # Instantiate a dummy model
885885 config = transformers .models .gemma3 .configuration_gemma3 .Gemma3TextConfig (
886- torch_dtype = torch .bfloat16 ,
886+ dtype = torch .bfloat16 ,
887887 rms_norm_eps = 1e-5 ,
888888 hidden_size = 32 ,
889889 intermediate_size = 64 ,
@@ -939,7 +939,7 @@ def test_apply_liger_kernel_to_instance_for_gemma3_conditional_generation():
939939
940940 # Instantiate a dummy model
941941 text_config = transformers .models .gemma3 .configuration_gemma3 .Gemma3TextConfig (
942- torch_dtype = torch .bfloat16 ,
942+ dtype = torch .bfloat16 ,
943943 rms_norm_eps = 1e-5 ,
944944 hidden_size = 32 ,
945945 intermediate_size = 64 ,
@@ -1026,7 +1026,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2():
10261026 with patch ("transformers.models.qwen2.modeling_qwen2" ):
10271027 # Instantiate a dummy model
10281028 config = transformers .models .qwen2 .configuration_qwen2 .Qwen2Config (
1029- torch_dtype = torch .bfloat16 ,
1029+ dtype = torch .bfloat16 ,
10301030 rms_norm_eps = 1e-5 ,
10311031 hidden_size = 32 ,
10321032 intermediate_size = 64 ,
@@ -1068,7 +1068,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3():
10681068
10691069 # Instantiate a dummy model
10701070 config = transformers .models .qwen3 .configuration_qwen3 .Qwen3Config (
1071- torch_dtype = torch .bfloat16 ,
1071+ dtype = torch .bfloat16 ,
10721072 rms_norm_eps = 1e-5 ,
10731073 hidden_size = 32 ,
10741074 intermediate_size = 64 ,
@@ -1110,7 +1110,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_moe():
11101110
11111111 # Instantiate a dummy model
11121112 config = transformers .models .qwen3_moe .configuration_qwen3_moe .Qwen3MoeConfig (
1113- torch_dtype = torch .bfloat16 ,
1113+ dtype = torch .bfloat16 ,
11141114 rms_norm_eps = 1e-5 ,
11151115 hidden_size = 32 ,
11161116 intermediate_size = 64 ,
@@ -1158,7 +1158,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation(
11581158
11591159 # Instantiate a dummy model
11601160 config = transformers .models .qwen2_vl .configuration_qwen2_vl .Qwen2VLConfig (
1161- torch_dtype = torch .bfloat16 ,
1161+ dtype = torch .bfloat16 ,
11621162 rms_norm_eps = 1e-5 ,
11631163 hidden_size = 32 ,
11641164 intermediate_size = 48 ,
@@ -1227,7 +1227,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl():
12271227
12281228 # Instantiate a dummy model
12291229 config = transformers .models .qwen2_vl .configuration_qwen2_vl .Qwen2VLConfig (
1230- torch_dtype = torch .bfloat16 ,
1230+ dtype = torch .bfloat16 ,
12311231 rms_norm_eps = 1e-5 ,
12321232 hidden_size = 32 ,
12331233 intermediate_size = 48 ,
@@ -1294,7 +1294,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_text():
12941294
12951295 # Instantiate a dummy model
12961296 config = transformers .models .qwen2_vl .configuration_qwen2_vl .Qwen2VLTextConfig (
1297- torch_dtype = torch .bfloat16 ,
1297+ dtype = torch .bfloat16 ,
12981298 rms_norm_eps = 1e-5 ,
12991299 hidden_size = 32 ,
13001300 intermediate_size = 48 ,
@@ -1347,7 +1347,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl():
13471347
13481348 # Instantiate a dummy model
13491349 config = transformers .models .qwen2_5_vl .configuration_qwen2_5_vl .Qwen2_5_VLConfig (
1350- torch_dtype = torch .bfloat16 ,
1350+ dtype = torch .bfloat16 ,
13511351 rms_norm_eps = 1e-5 ,
13521352 hidden_size = 32 ,
13531353 intermediate_size = 48 ,
@@ -1416,7 +1416,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_for_conditional_generatio
14161416
14171417 # Instantiate a dummy model
14181418 config = transformers .models .qwen2_5_vl .configuration_qwen2_5_vl .Qwen2_5_VLConfig (
1419- torch_dtype = torch .bfloat16 ,
1419+ dtype = torch .bfloat16 ,
14201420 rms_norm_eps = 1e-5 ,
14211421 hidden_size = 32 ,
14221422 intermediate_size = 48 ,
@@ -1483,7 +1483,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_text():
14831483
14841484 # Instantiate a dummy model
14851485 config = transformers .models .qwen2_5_vl .configuration_qwen2_5_vl .Qwen2_5_VLTextConfig (
1486- torch_dtype = torch .bfloat16 ,
1486+ dtype = torch .bfloat16 ,
14871487 rms_norm_eps = 1e-5 ,
14881488 hidden_size = 32 ,
14891489 intermediate_size = 48 ,
@@ -1528,7 +1528,7 @@ def test_apply_liger_kernel_to_instance_for_phi3():
15281528 with patch ("transformers.models.phi3.modeling_phi3" ):
15291529 # Instantiate a dummy model
15301530 config = transformers .models .phi3 .configuration_phi3 .Phi3Config (
1531- torch_dtype = torch .bfloat16 ,
1531+ dtype = torch .bfloat16 ,
15321532 rms_norm_eps = 1e-5 ,
15331533 hidden_size = 32 ,
15341534 intermediate_size = 64 ,
@@ -1570,7 +1570,7 @@ def test_apply_liger_kernel_to_instance_for_olmo2():
15701570
15711571 # Instantiate a dummy model
15721572 config = transformers .models .olmo2 .configuration_olmo2 .Olmo2Config (
1573- torch_dtype = torch .bfloat16 ,
1573+ dtype = torch .bfloat16 ,
15741574 rms_norm_eps = 1e-5 ,
15751575 hidden_size = 32 ,
15761576 intermediate_size = 64 ,
@@ -1616,7 +1616,7 @@ def test_apply_liger_kernel_to_instance_for_glm4():
16161616
16171617 # Instantiate a dummy model
16181618 config = transformers .models .glm4 .configuration_glm4 .Glm4Config (
1619- torch_dtype = torch .bfloat16 ,
1619+ dtype = torch .bfloat16 ,
16201620 rms_norm_eps = 1e-5 ,
16211621 hidden_size = 32 ,
16221622 intermediate_size = 64 ,
@@ -1664,7 +1664,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v():
16641664
16651665 # Instantiate a dummy model
16661666 config = transformers .models .glm4v .configuration_glm4v .Glm4vConfig (
1667- torch_dtype = torch .bfloat16 ,
1667+ dtype = torch .bfloat16 ,
16681668 text_config = {
16691669 "num_hidden_layers" : 2 ,
16701670 "rms_norm_eps" : 1e-5 ,
@@ -1734,7 +1734,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v_moe():
17341734
17351735 # Instantiate a dummy model
17361736 config = transformers .models .glm4v_moe .configuration_glm4v_moe .Glm4vMoeConfig (
1737- torch_dtype = torch .bfloat16 ,
1737+ dtype = torch .bfloat16 ,
17381738 hidden_size = 32 ,
17391739 num_attention_heads = 4 ,
17401740 num_key_value_heads = 2 ,
@@ -1837,7 +1837,7 @@ def test_apply_liger_kernel_to_instance_for_smollm3():
18371837 with patch ("transformers.models.smollm3.modeling_smollm3" ):
18381838 # Instantiate a dummy model
18391839 config = transformers .models .smollm3 .configuration_smollm3 .SmolLM3Config (
1840- torch_dtype = torch .bfloat16 ,
1840+ dtype = torch .bfloat16 ,
18411841 rms_norm_eps = 1e-5 ,
18421842 hidden_size = 32 ,
18431843 intermediate_size = 64 ,
0 commit comments