forked from ZHZisZZ/dllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsft.py
More file actions
115 lines (96 loc) · 3.92 KB
/
sft.py
File metadata and controls
115 lines (96 loc) · 3.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/llada/sft.py \
--load_in_4bit True --lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/llada/sft.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch, and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (FSDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py"
- 2 Nodes, 16 GPUs (FSDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py"
"""
import os
from dataclasses import dataclass, field
from functools import partial
import transformers
import accelerate
import dllm
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    """Model arguments for this example; the default checkpoint is LLaDA-8B-Base."""

    model_name_or_path: str = field(default="GSAI-ML/LLaDA-8B-Base")
@dataclass
class DataArguments(dllm.utils.DataArguments):
    """Data arguments for SFT.

    Every override carries ``help`` metadata so ``--help`` documents it.
    Redeclaring a field in a subclass replaces the parent's field (including
    any metadata it had), so the help text has to be restated here.
    """

    dataset_args: str = field(
        default="allenai/tulu-3-sft-mixture[train:10000,test:1000]",
        metadata={
            "help": "Dataset spec passed to dllm.data.load_sft_dataset, e.g. "
            "'allenai/tulu-3-sft-mixture[train:10000,test:1000]'"
        },
    )
    load_preprocessed_data: bool = field(
        default=False,
        metadata={
            "help": "Whether the dataset is already preprocessed; if set, "
            "the SFT tokenization map step is skipped"
        },
    )
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens"},
    )
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    """Training arguments for this example: output location and length grouping."""

    output_dir: str = field(
        default="models/LLaDA-8B-SFT/tulu-3-sft-mixture[train:10000,test:1000]"
    )
    group_by_length: bool = field(default=True)
def _build_dataset(data_args, tokenizer):
    """Load the SFT dataset and return it ready for the trainer.

    Unless ``data_args.load_preprocessed_data`` is set, each example is mapped
    through ``dllm.utils.default_sft_map_fn`` (optionally masking prompt tokens
    out of the loss). Both paths are then passed through
    ``dllm.utils.post_process_dataset``.
    """
    dataset = dllm.data.load_sft_dataset(
        data_args.dataset_args,
        load_preprocessed_data=data_args.load_preprocessed_data,
    )
    if not data_args.load_preprocessed_data:
        map_fn = partial(
            dllm.utils.default_sft_map_fn,
            tokenizer=tokenizer,
            mask_prompt_loss=data_args.mask_prompt_loss,
        )
        dataset = dataset.map(map_fn, num_proc=data_args.num_proc)
    # truncate / filter long sequences if needed
    return dllm.utils.post_process_dataset(dataset, data_args)


def train():
    """Parse CLI arguments, fine-tune the model with MDLMTrainer, and save the result.

    The final weights and tokenizer are written to
    ``<output_dir>/checkpoint-final``.
    """
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    model = dllm.utils.get_model(model_args=model_args)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    # The local main process builds the dataset first; the other local processes
    # wait and then run the same code (presumably reusing its cache — confirm).
    with accelerate.PartialState().local_main_process_first():
        dataset = _build_dataset(data_args, tokenizer)

    # ----- Training --------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    dllm.utils.print_main("start training...")
    trainer = dllm.core.trainers.MDLMTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=dllm.utils.NoAttentionMaskCollator(
            tokenizer,
            return_tensors="pt",
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,  # finetune on padding <eos_token>
        ),
    )
    trainer.train()

    final_dir = os.path.join(training_args.output_dir, "checkpoint-final")
    # save_model is still invoked on every process, as before; the Trainer
    # decides internally which processes actually write weights.
    trainer.save_model(final_dir)
    # Guard the tokenizer write: without it every process of a distributed run
    # writes the same files into the same directory concurrently.
    if training_args.should_save:
        trainer.processing_class.save_pretrained(final_dir)


if __name__ == "__main__":
    train()