Skip to content

Commit 02d2764

Browse files
authored
Replace all torch_dtype with dtype (#881)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Fix #880 <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Tcc0403 <[email protected]>
1 parent 960f85d commit 02d2764

File tree

9 files changed

+36
-36
lines changed

9 files changed

+36
-36
lines changed

docs/Examples.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ from liger_kernel.transformers.trainer import LigerORPOTrainer # noqa: F401
239239

240240
model = AutoModelForCausalLM.from_pretrained(
241241
"meta-llama/Llama-3.2-1B-Instruct",
242-
torch_dtype=torch.bfloat16,
242+
dtype=torch.bfloat16,
243243
)
244244

245245
tokenizer = AutoTokenizer.from_pretrained(

examples/alignment/run_orpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
model = AutoModelForCausalLM.from_pretrained(
1111
"meta-llama/Llama-3.2-1B-Instruct",
12-
torch_dtype=torch.bfloat16,
12+
dtype=torch.bfloat16,
1313
)
1414

1515
tokenizer = AutoTokenizer.from_pretrained(

examples/huggingface/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def train():
4848
custom_args.model_name,
4949
trust_remote_code=True,
5050
use_cache=False,
51-
torch_dtype=torch.bfloat16,
51+
dtype=torch.bfloat16,
5252
# These args will get passed to the appropriate apply_liger_kernel_to_* function
5353
# to override the default settings
5454
# cross_entropy=True,
@@ -59,7 +59,7 @@ def train():
5959
custom_args.model_name,
6060
trust_remote_code=True,
6161
use_cache=False,
62-
torch_dtype=torch.bfloat16,
62+
dtype=torch.bfloat16,
6363
)
6464

6565
trainer = SFTTrainer(

examples/huggingface/training_multimodal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn.
5656
model = Qwen2VLForConditionalGeneration.from_pretrained(
5757
pretrained_model_name_or_path=model_name,
5858
use_cache=False,
59-
torch_dtype=torch.bfloat16,
59+
dtype=torch.bfloat16,
6060
low_cpu_mem_usage=True,
6161
attn_implementation="sdpa",
6262
)

examples/medusa/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _model_loader():
319319
model = model_builder(
320320
model_args.model_name_or_path,
321321
cache_dir=training_args.cache_dir,
322-
torch_dtype=torch.bfloat16,
322+
dtype=torch.bfloat16,
323323
)
324324

325325
# Freeze the base model

src/liger_kernel/transformers/fused_linear_cross_entropy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
assert reduction in {
2626
"mean",
2727
"sum",
28-
"none",
28+
"none",
2929
}, f"reduction must be 'mean' or 'sum' or 'none'. Got: {reduction}"
3030
assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
3131
self.ce_weight = ce_weight

src/liger_kernel/transformers/model/glm4v.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def lce_forward(
7070
>>> processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
7171
>>> model = Glm4vForConditionalGeneration.from_pretrained(
7272
pretrained_model_name_or_path=MODEL_PATH,
73-
torch_dtype=torch.bfloat16,
73+
dtype=torch.bfloat16,
7474
device_map="auto",
7575
)
7676
>>> inputs = processor.apply_chat_template(

src/liger_kernel/transformers/model/glm4v_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def lce_forward(
7575
>>> processor = AutoProcessor.from_pretrained(MODEL_PATH)
7676
>>> model = Glm4vMoeForConditionalGeneration.from_pretrained(
7777
pretrained_model_name_or_path=MODEL_PATH,
78-
torch_dtype="auto",
78+
dtype="auto",
7979
device_map="auto",
8080
)
8181
>>> inputs = processor.apply_chat_template(

test/transformers/test_monkey_patch.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def test_apply_liger_kernel_to_instance_for_llama():
338338
with patch("transformers.models.llama.modeling_llama"):
339339
# Instantiate a dummy model
340340
config = transformers.models.llama.configuration_llama.LlamaConfig(
341-
torch_dtype=torch.bfloat16,
341+
dtype=torch.bfloat16,
342342
rms_norm_eps=1e-5,
343343
hidden_size=32,
344344
intermediate_size=64,
@@ -382,7 +382,7 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
382382

383383
# Instantiate a dummy model
384384
config = transformers.models.mllama.configuration_mllama.MllamaConfig(
385-
torch_dtype=torch.bfloat16,
385+
dtype=torch.bfloat16,
386386
text_config=transformers.models.mllama.configuration_mllama.MllamaTextConfig(
387387
rms_norm_eps=1e-5,
388388
hidden_size=32,
@@ -533,7 +533,7 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_causal_lm():
533533

534534
# Instantiate a dummy model
535535
config = transformers.models.llama4.configuration_llama4.Llama4TextConfig(
536-
torch_dtype=torch.bfloat16,
536+
dtype=torch.bfloat16,
537537
rms_norm_eps=1e-5,
538538
hidden_size=32,
539539
intermediate_size=64,
@@ -573,9 +573,9 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_conditional_generation():
573573

574574
# Instantiate a dummy model
575575
config = transformers.models.llama4.configuration_llama4.Llama4Config(
576-
torch_dtype=torch.bfloat16,
576+
dtype=torch.bfloat16,
577577
text_config=transformers.models.llama4.configuration_llama4.Llama4TextConfig(
578-
torch_dtype=torch.bfloat16,
578+
dtype=torch.bfloat16,
579579
rms_norm_eps=1e-5,
580580
hidden_size=32,
581581
intermediate_size=64,
@@ -656,7 +656,7 @@ def test_apply_liger_kernel_to_instance_for_mistral():
656656
with patch("transformers.models.mistral.modeling_mistral"):
657657
# Instantiate a dummy model
658658
config = transformers.models.mistral.configuration_mistral.MistralConfig(
659-
torch_dtype=torch.bfloat16,
659+
dtype=torch.bfloat16,
660660
rms_norm_eps=1e-5,
661661
hidden_size=32,
662662
intermediate_size=64,
@@ -695,7 +695,7 @@ def test_apply_liger_kernel_to_instance_for_mixtral():
695695
with patch("transformers.models.mixtral.modeling_mixtral"):
696696
# Instantiate a dummy model
697697
config = transformers.models.mixtral.configuration_mixtral.MixtralConfig(
698-
torch_dtype=torch.bfloat16,
698+
dtype=torch.bfloat16,
699699
rms_norm_eps=1e-5,
700700
hidden_size=32,
701701
intermediate_size=64,
@@ -738,7 +738,7 @@ def test_apply_liger_kernel_to_instance_for_gemma():
738738
with patch("transformers.models.gemma.modeling_gemma"):
739739
# Instantiate a dummy model
740740
config = transformers.models.gemma.configuration_gemma.GemmaConfig(
741-
torch_dtype=torch.bfloat16,
741+
dtype=torch.bfloat16,
742742
rms_norm_eps=1e-5,
743743
hidden_size=32,
744744
intermediate_size=64,
@@ -777,7 +777,7 @@ def test_apply_liger_kernel_to_instance_for_gemma2():
777777
with patch("transformers.models.gemma2.modeling_gemma2"):
778778
# Instantiate a dummy model
779779
config = transformers.models.gemma2.configuration_gemma2.Gemma2Config(
780-
torch_dtype=torch.bfloat16,
780+
dtype=torch.bfloat16,
781781
rms_norm_eps=1e-5,
782782
hidden_size=32,
783783
intermediate_size=64,
@@ -827,7 +827,7 @@ def test_apply_liger_kernel_to_instance_for_paligemma():
827827

828828
# Instantiate a dummy model
829829
config = transformers.models.paligemma.configuration_paligemma.PaliGemmaConfig(
830-
torch_dtype=torch.bfloat16,
830+
dtype=torch.bfloat16,
831831
text_config={
832832
"num_hidden_layers": 2,
833833
"rms_norm_eps": 1e-5,
@@ -883,7 +883,7 @@ def test_apply_liger_kernel_to_instance_for_gemma3_text():
883883

884884
# Instantiate a dummy model
885885
config = transformers.models.gemma3.configuration_gemma3.Gemma3TextConfig(
886-
torch_dtype=torch.bfloat16,
886+
dtype=torch.bfloat16,
887887
rms_norm_eps=1e-5,
888888
hidden_size=32,
889889
intermediate_size=64,
@@ -939,7 +939,7 @@ def test_apply_liger_kernel_to_instance_for_gemma3_conditional_generation():
939939

940940
# Instantiate a dummy model
941941
text_config = transformers.models.gemma3.configuration_gemma3.Gemma3TextConfig(
942-
torch_dtype=torch.bfloat16,
942+
dtype=torch.bfloat16,
943943
rms_norm_eps=1e-5,
944944
hidden_size=32,
945945
intermediate_size=64,
@@ -1026,7 +1026,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2():
10261026
with patch("transformers.models.qwen2.modeling_qwen2"):
10271027
# Instantiate a dummy model
10281028
config = transformers.models.qwen2.configuration_qwen2.Qwen2Config(
1029-
torch_dtype=torch.bfloat16,
1029+
dtype=torch.bfloat16,
10301030
rms_norm_eps=1e-5,
10311031
hidden_size=32,
10321032
intermediate_size=64,
@@ -1068,7 +1068,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3():
10681068

10691069
# Instantiate a dummy model
10701070
config = transformers.models.qwen3.configuration_qwen3.Qwen3Config(
1071-
torch_dtype=torch.bfloat16,
1071+
dtype=torch.bfloat16,
10721072
rms_norm_eps=1e-5,
10731073
hidden_size=32,
10741074
intermediate_size=64,
@@ -1110,7 +1110,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_moe():
11101110

11111111
# Instantiate a dummy model
11121112
config = transformers.models.qwen3_moe.configuration_qwen3_moe.Qwen3MoeConfig(
1113-
torch_dtype=torch.bfloat16,
1113+
dtype=torch.bfloat16,
11141114
rms_norm_eps=1e-5,
11151115
hidden_size=32,
11161116
intermediate_size=64,
@@ -1158,7 +1158,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation(
11581158

11591159
# Instantiate a dummy model
11601160
config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig(
1161-
torch_dtype=torch.bfloat16,
1161+
dtype=torch.bfloat16,
11621162
rms_norm_eps=1e-5,
11631163
hidden_size=32,
11641164
intermediate_size=48,
@@ -1227,7 +1227,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl():
12271227

12281228
# Instantiate a dummy model
12291229
config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig(
1230-
torch_dtype=torch.bfloat16,
1230+
dtype=torch.bfloat16,
12311231
rms_norm_eps=1e-5,
12321232
hidden_size=32,
12331233
intermediate_size=48,
@@ -1294,7 +1294,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_vl_text():
12941294

12951295
# Instantiate a dummy model
12961296
config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLTextConfig(
1297-
torch_dtype=torch.bfloat16,
1297+
dtype=torch.bfloat16,
12981298
rms_norm_eps=1e-5,
12991299
hidden_size=32,
13001300
intermediate_size=48,
@@ -1347,7 +1347,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl():
13471347

13481348
# Instantiate a dummy model
13491349
config = transformers.models.qwen2_5_vl.configuration_qwen2_5_vl.Qwen2_5_VLConfig(
1350-
torch_dtype=torch.bfloat16,
1350+
dtype=torch.bfloat16,
13511351
rms_norm_eps=1e-5,
13521352
hidden_size=32,
13531353
intermediate_size=48,
@@ -1416,7 +1416,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_for_conditional_generatio
14161416

14171417
# Instantiate a dummy model
14181418
config = transformers.models.qwen2_5_vl.configuration_qwen2_5_vl.Qwen2_5_VLConfig(
1419-
torch_dtype=torch.bfloat16,
1419+
dtype=torch.bfloat16,
14201420
rms_norm_eps=1e-5,
14211421
hidden_size=32,
14221422
intermediate_size=48,
@@ -1483,7 +1483,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2_5_vl_text():
14831483

14841484
# Instantiate a dummy model
14851485
config = transformers.models.qwen2_5_vl.configuration_qwen2_5_vl.Qwen2_5_VLTextConfig(
1486-
torch_dtype=torch.bfloat16,
1486+
dtype=torch.bfloat16,
14871487
rms_norm_eps=1e-5,
14881488
hidden_size=32,
14891489
intermediate_size=48,
@@ -1528,7 +1528,7 @@ def test_apply_liger_kernel_to_instance_for_phi3():
15281528
with patch("transformers.models.phi3.modeling_phi3"):
15291529
# Instantiate a dummy model
15301530
config = transformers.models.phi3.configuration_phi3.Phi3Config(
1531-
torch_dtype=torch.bfloat16,
1531+
dtype=torch.bfloat16,
15321532
rms_norm_eps=1e-5,
15331533
hidden_size=32,
15341534
intermediate_size=64,
@@ -1570,7 +1570,7 @@ def test_apply_liger_kernel_to_instance_for_olmo2():
15701570

15711571
# Instantiate a dummy model
15721572
config = transformers.models.olmo2.configuration_olmo2.Olmo2Config(
1573-
torch_dtype=torch.bfloat16,
1573+
dtype=torch.bfloat16,
15741574
rms_norm_eps=1e-5,
15751575
hidden_size=32,
15761576
intermediate_size=64,
@@ -1616,7 +1616,7 @@ def test_apply_liger_kernel_to_instance_for_glm4():
16161616

16171617
# Instantiate a dummy model
16181618
config = transformers.models.glm4.configuration_glm4.Glm4Config(
1619-
torch_dtype=torch.bfloat16,
1619+
dtype=torch.bfloat16,
16201620
rms_norm_eps=1e-5,
16211621
hidden_size=32,
16221622
intermediate_size=64,
@@ -1664,7 +1664,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v():
16641664

16651665
# Instantiate a dummy model
16661666
config = transformers.models.glm4v.configuration_glm4v.Glm4vConfig(
1667-
torch_dtype=torch.bfloat16,
1667+
dtype=torch.bfloat16,
16681668
text_config={
16691669
"num_hidden_layers": 2,
16701670
"rms_norm_eps": 1e-5,
@@ -1734,7 +1734,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v_moe():
17341734

17351735
# Instantiate a dummy model
17361736
config = transformers.models.glm4v_moe.configuration_glm4v_moe.Glm4vMoeConfig(
1737-
torch_dtype=torch.bfloat16,
1737+
dtype=torch.bfloat16,
17381738
hidden_size=32,
17391739
num_attention_heads=4,
17401740
num_key_value_heads=2,
@@ -1837,7 +1837,7 @@ def test_apply_liger_kernel_to_instance_for_smollm3():
18371837
with patch("transformers.models.smollm3.modeling_smollm3"):
18381838
# Instantiate a dummy model
18391839
config = transformers.models.smollm3.configuration_smollm3.SmolLM3Config(
1840-
torch_dtype=torch.bfloat16,
1840+
dtype=torch.bfloat16,
18411841
rms_norm_eps=1e-5,
18421842
hidden_size=32,
18431843
intermediate_size=64,

0 commit comments

Comments
 (0)