forked from ZHZisZZ/dllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpt.py
More file actions
167 lines (146 loc) · 6.04 KB
/
pt.py
File metadata and controls
167 lines (146 loc) · 6.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/llada/pt.py \
--load_in_4bit True --lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/llada/pt.py
Slurm users
------------
Note: run `mkdir logs` before running sbatch, and adjust
`partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 24 Nodes, 192 GPUs (FSDP):
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/pt.py"
"""
import os
import functools
from dataclasses import dataclass, field
import torch
import transformers
import accelerate
import dllm
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    """Model options for LLaDA pre-training.

    Only the configuration found at ``model_name_or_path`` is used; ``train()``
    builds the model from that config, so no pretrained weights are loaded.
    """

    # Alternative checkpoint: "inclusionAI/LLaDA-MoE-7B-A1B-Base"
    model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
@dataclass
class DataArguments(dllm.utils.DataArguments):
    """Dataset and packing options for LLaDA pre-training."""

    # Dataset spec with split slices; passed to ``dllm.data.load_pt_dataset``.
    dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
    # Name of the column holding raw text to tokenize.
    text_field: str = field(default="text")
    # Load the dataset lazily; also disables batch dispatching in ``train()``.
    streaming: bool = field(default=True)
    # Forwarded to ``dllm.utils.tokenize_and_group`` as ``drop_tail``.
    drop_tail: bool = field(default=True)
    insert_eos: bool = field(
        metadata=dict(
            help="False when adjacent samples from the datasets are semantically coherent."
        ),
        default=True,
    )
    random_length_ratio: float = field(
        metadata=dict(
            help="The probability of randomly cut sequences during training. "
            "See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training for reference."
        ),
        default=0.01,
    )
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    """Optimizer, schedule and checkpointing defaults for the pre-training run."""

    output_dir: str = "models/LLaDA-8B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"

    # Optimization.
    learning_rate: float = 3e-4
    max_steps: int = 2000
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 4

    # Evaluation / checkpoint cadence. NOTE(review): fractional values are
    # presumably interpreted by transformers as a ratio of total steps — confirm.
    eval_steps: float = 0.05
    save_steps: float = 0.05
def random_truncate_batch(outputs, random_length_ratio, generator=None):
    """Randomly truncate a collated batch and drop an all-ones attention mask.

    With probability ``random_length_ratio`` every sequence in the batch is cut
    to one shared length drawn uniformly from ``[1, seq_len]``. This follows the
    LLaDA pre-training guideline of training on a small fraction of
    random-length sequences
    (https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training).

    Args:
        outputs: Mapping of batch tensors (e.g. a ``BatchEncoding``) that
            contains ``"input_ids"`` and optionally ``"labels"`` and
            ``"attention_mask"``. Assumed to be 2-D ``(batch, seq_len)``
            tensors — TODO confirm for all collators used with this helper.
        random_length_ratio: Probability of truncating this batch.
        generator: Optional ``torch.Generator`` for reproducible sampling; the
            global torch RNG is used when ``None`` (same draws as before).

    Returns:
        The same ``outputs`` object, modified in place. ``"attention_mask"`` is
        removed when it is present and contains only ones.
    """
    if torch.rand(1, generator=generator).item() < random_length_ratio:
        seq_len = outputs["input_ids"].shape[1]
        # A zero-length batch cannot be truncated (randint(1, 1) would raise).
        if seq_len > 0:
            # Use a plain int as the slice bound rather than relying on the
            # implicit 1-element tensor -> index conversion.
            random_length = int(
                torch.randint(1, seq_len + 1, (1,), generator=generator).item()
            )
            for key in ("input_ids", "labels", "attention_mask"):
                if key in outputs:
                    outputs[key] = outputs[key][:, :random_length]
    # If there is no padding the mask is all ones; remove it as the original
    # collator did. Guarded so a batch without a mask does not raise KeyError.
    attention_mask = outputs.get("attention_mask")
    if attention_mask is not None and torch.all(attention_mask == 1):
        outputs.pop("attention_mask")
    return outputs


def train():
    """Pre-train a LLaDA-style model from scratch with ``MDLMTrainer``.

    Parses ``ModelArguments``/``DataArguments``/``TrainingArguments`` from the
    command line, builds a bfloat16 model from the configuration only (random
    init), tokenizes and groups the dataset into ``max_length`` chunks, trains,
    and saves model + tokenizer to ``<output_dir>/checkpoint-final``.
    """
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # necessary for streaming dataset
    if data_args.streaming:
        training_args.accelerator_config.dispatch_batches = False
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    # initialize model weights from scratch: only the config is loaded
    config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
    with dllm.utils.init_device_context_manager():
        model = transformers.AutoModel.from_config(
            config, dtype=torch.bfloat16, init_params=True
        )

    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Optional PEFT: LoRA ----------------------------------------------------
    model = dllm.utils.load_peft(model=model, model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_pt_dataset(
            data_args.dataset_args,
            streaming=data_args.streaming,
        )
        dataset = dataset.map(
            functools.partial(
                dllm.utils.tokenize_and_group,
                tokenizer=tokenizer,
                text_field=data_args.text_field,
                seq_length=data_args.max_length,
                insert_eos=data_args.insert_eos,
                drop_tail=data_args.drop_tail,
            ),
            batched=True,
            # num_proc is not passed for streaming datasets
            num_proc=None if data_args.streaming else data_args.num_proc,
            remove_columns=dataset["train"].column_names,
        )
        if data_args.streaming:
            dataset = dataset.shuffle(seed=training_args.seed)

    # ----- Training --------------------------------------------------------------
    @dataclass
    class LLaDAPTCollator(transformers.DataCollatorForSeq2Seq):
        """Seq2Seq collator that occasionally truncates batches to a random length."""

        # Reference: https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training
        # By default, 1% of the pre-training data are truncated to a random length
        random_length_ratio: float = 0.01

        def __call__(self, features, return_tensors=None):
            outputs = super().__call__(features, return_tensors)
            return random_truncate_batch(outputs, self.random_length_ratio)

    accelerate.PartialState().wait_for_everyone()
    dllm.utils.print_main("start training...")
    trainer = dllm.core.trainers.MDLMTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=LLaDAPTCollator(
            tokenizer,
            return_tensors="pt",
            padding=True,
            random_length_ratio=data_args.random_length_ratio,
        ),
    )
    trainer.train()

    final_dir = os.path.join(training_args.output_dir, "checkpoint-final")
    trainer.save_model(final_dir)
    trainer.processing_class.save_pretrained(final_dir)
if __name__ == "__main__":
    # Script entry point, e.g. via the `accelerate launch` / sbatch commands above.
    train()