Commit a6d28d1

feat: add glm and glm4 multipack and cce (axolotl-ai-cloud#2546)

* feat: add glm and glm4 multipack
* feat: add glm4 example
* feat: add cce for glm

1 parent 32e335d · commit a6d28d1

File tree

- examples/glm4/qlora-32b.yaml
- src/axolotl/integrations/cut_cross_entropy/README.md
- src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py
- src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py
- src/axolotl/monkeypatch/multipack.py

5 files changed: +129 −0 lines changed
examples/glm4/qlora-32b.yaml

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+base_model: THUDM/GLM-4-32B-0414
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_4bit: true
+
+datasets:
+  - path: teknium/GPT4-LLM-Cleaned
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0
+output_dir: ./outputs/qlora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+eval_sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:

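This example enables `sample_packing: true`, which depends on the multipack support for `glm4` added further down in this commit, and it would typically be launched with the axolotl CLI (e.g. `axolotl train examples/glm4/qlora-32b.yaml`). A quick sanity-check sketch, assuming PyYAML is installed and the script is run from the repo root:

```python
import yaml  # assumes PyYAML

# Load the example config added by this commit and confirm the keys that
# depend on the new GLM-4 multipack support are enabled.
with open("examples/glm4/qlora-32b.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["sample_packing"] is True
assert cfg["flash_attention"] is True
print(cfg["base_model"], cfg["sequence_len"], cfg["micro_batch_size"])
```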
src/axolotl/integrations/cut_cross_entropy/README.md

Lines changed: 2 additions & 0 deletions

@@ -47,6 +47,8 @@ cut_cross_entropy: true
 - qwen2
 - cohere
 - cohere2
+- glm
+- glm4

 ## Citation
src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+"""GLM 4 patch. GLM family inherits from Llama."""
+
+from types import MethodType
+
+import transformers
+from cut_cross_entropy.transformers.utils import (
+    PatchOptions,
+    TransformersModelT,
+)
+
+
+def patch_glm(
+    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
+    patch_options: PatchOptions,
+) -> TransformersModelT | None:
+
+    # Set the _PATCH_OPTS in the llama patch file
+    import cut_cross_entropy.transformers.llama as llama_patch
+
+    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access
+
+    from cut_cross_entropy.transformers.llama import cce_forward
+    from transformers.models.glm import modeling_glm
+
+    if isinstance(maybe_model, transformers.PreTrainedModel):
+        assert isinstance(
+            maybe_model, modeling_glm.GlmForCausalLM
+        ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
+        maybe_model.forward = MethodType(cce_forward, maybe_model)
+        return maybe_model
+
+    modeling_glm.GlmForCausalLM.forward = cce_forward
+    return None
+
+
+def patch_glm4(
+    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
+    patch_options: PatchOptions,
+) -> TransformersModelT | None:
+
+    # Set the _PATCH_OPTS in the llama patch file
+    import cut_cross_entropy.transformers.llama as llama_patch
+
+    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access
+
+    from cut_cross_entropy.transformers.llama import cce_forward
+    from transformers.models.glm4 import modeling_glm4
+
+    if isinstance(maybe_model, transformers.PreTrainedModel):
+        assert isinstance(
+            maybe_model, modeling_glm4.Glm4ForCausalLM
+        ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
+        maybe_model.forward = MethodType(cce_forward, maybe_model)
+        return maybe_model
+
+    modeling_glm4.Glm4ForCausalLM.forward = cce_forward
+    return None

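Both patch functions follow the same two-branch pattern: if an instantiated model is passed, bind `cce_forward` to that one instance via `types.MethodType`; otherwise patch the class so every future instance picks it up. A minimal, self-contained sketch of that pattern, with illustrative names that are not from the repo:

```python
from types import MethodType


class TinyModel:
    """Stand-in for a transformers PreTrainedModel."""

    def forward(self, x):
        return x + 1


def cce_forward(self, x):
    """Stand-in for cut_cross_entropy's memory-efficient forward."""
    return x * 2


model = TinyModel()

# Instance-level patch: mirrors the isinstance(maybe_model, PreTrainedModel) branch.
model.forward = MethodType(cce_forward, model)
print(model.forward(3))  # 6, while other TinyModel instances are untouched

# Class-level patch: mirrors the fallback branch, taken when only a config or
# model name is passed; every model instantiated afterwards uses the new forward.
TinyModel.forward = cce_forward
print(TinyModel().forward(3))  # 6
```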
src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py

Lines changed: 6 additions & 0 deletions

@@ -20,6 +20,10 @@
     patch_gemma3,
     patch_gemma3_text,
 )
+from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
+    patch_glm,
+    patch_glm4,
+)
 from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
     patch_llama4,
     patch_llama4_text,

@@ -45,6 +49,8 @@
     "qwen2": patch_qwen2,
     "cohere": patch_cohere,
     "cohere2": patch_cohere2,
+    "glm": patch_glm,
+    "glm4": patch_glm4,
 }

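The second hunk extends a model_type → patch-function mapping, so the new GLM entries are dispatched the same way as the other architectures. A hypothetical sketch of that dispatch; the real mapping and helper names in patch.py are not visible in this hunk, so `CCE_PATCH_FNS` and `apply_cce_patch` are illustrative only:

```python
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import patch_glm, patch_glm4

# Hypothetical mapping name; only the "glm"/"glm4" entries come from the diff above.
CCE_PATCH_FNS = {
    "glm": patch_glm,
    "glm4": patch_glm4,
}


def apply_cce_patch(model_type, maybe_model, patch_options):
    """Look up the per-architecture patch and apply it (illustrative helper)."""
    try:
        patch_fn = CCE_PATCH_FNS[model_type]
    except KeyError as err:
        raise ValueError(f"cut_cross_entropy has no patch for model_type={model_type!r}") from err
    return patch_fn(maybe_model, patch_options)
```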
src/axolotl/monkeypatch/multipack.py

Lines changed: 2 additions & 0 deletions

@@ -31,6 +31,8 @@
     "starcoder2",
     "deepseek_v2",
     "deepseek_v3",
+    "glm",
+    "glm4",
 ]

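This list gates sample packing (multipack) by model type, so adding "glm" and "glm4" is what lets the `sample_packing: true` setting in the example config take effect for these architectures. A hypothetical gating sketch; the constant and helper names below are assumptions, only the list entries appear in the hunk:

```python
# Assumed constant name; the diff only shows the entries being appended.
SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "deepseek_v3",
    "glm",
    "glm4",
]


def supports_multipack(model_type: str) -> bool:
    """Return True if sample packing can be enabled for this model type (illustrative)."""
    return model_type in SUPPORTED_MULTIPACK_MODEL_TYPES
```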