-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtrain_fsdp2.py
More file actions
272 lines (231 loc) · 10.3 KB
/
train_fsdp2.py
File metadata and controls
272 lines (231 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import os
import sys
import types
from dataclasses import dataclass, field, fields
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from transformers import AutoConfig, AutoModelForCausalLM
from mindspeed_llm.fsdp2.models.model_factory import ModelFactory
from mindspeed_llm.fsdp2.optim.optimizer import OptimizerFactory
from mindspeed_llm.fsdp2.optim.scheduler import SchedulerFactory
from mindspeed_llm.fsdp2.checkpoint.checkpoint_manager import CheckpointManager
from mindspeed_llm.fsdp2.train.trainer import Trainer
from mindspeed_llm.fsdp2.data.data_factory import DataFactory
from mindspeed_llm.fsdp2.data.tokenizer import TokenizerFactory
from mindspeed_llm.fsdp2.data.template import get_template_and_fix_tokenizer
from mindspeed_llm.fsdp2.utils.logging import setup_global_logging, get_logger
from mindspeed_llm.fsdp2.utils.arguments import (
ModelArguments, DataArguments, ParallelArguments, TrainingArguments, OptimizationArguments, fsdp2_parse_args
)
from mindspeed_llm.fsdp2.utils.global_vars import set_args
from mindspeed_llm.fsdp2.utils.train_monitor import TrainMonitor
from mindspeed.fsdp.utils.device import set_accelerator_compatible
from mindspeed.fsdp.utils.random import set_seed
from mindspeed.fsdp.utils.torch_patch import apply_hccl_premul_sum_patch
from mindspeed_llm.training.utils import auto_coverage
logger = get_logger(__name__)

# ==============================================================================
# [Arguments Definition] Arguments Class for MindSpeed FSDP Scheme
# ==============================================================================
@dataclass
class Arguments:
    """
    Root arguments class grouping the model, data, parallel, training and
    optimization argument dataclasses.

    This class is what ``MindSpeedAutoTrainer._parse_args`` hands to
    ``fsdp2_parse_args``. Every group is built through ``default_factory`` so
    each ``Arguments`` instance owns fresh sub-group objects.
    """

    # Model options; e.g. ``model_name_or_path`` is read when building the monitor config.
    model: ModelArguments = field(default_factory=ModelArguments)
    # Data options; e.g. ``data_manager_type`` selects the data manager implementation.
    data: DataArguments = field(default_factory=DataArguments)
    # Parallelism options; e.g. ``ep_size`` is forwarded to the optimizer factory.
    parallel: ParallelArguments = field(default_factory=ParallelArguments)
    # Training-loop options (seed, lr, max_steps, scheduler, resume checkpoint, ...).
    training: TrainingArguments = field(default_factory=TrainingArguments)
    # Optimization options; passed through to the Trainer as ``optimization_args``.
    optimization: OptimizationArguments = field(default_factory=OptimizationArguments)
# ==============================================================================
# AutoTrainer
# ==============================================================================
class MindSpeedAutoTrainer:
    """
    AutoTrainer: Dependency Injection Container.

    Based on FSDP2 Arguments (HfArgumentParser style). On construction it:
      1. parses the five argument groups,
      2. selects the NPU device, seeds RNGs and initializes torch.distributed,
      3. builds model, tokenizer, template, data manager, optimizer,
         LR scheduler, checkpointer and monitor,
      4. injects all of them into a ``Trainer``.
    """

    # Step count handed to the scheduler factory when ``max_steps`` is not
    # positive (epoch mode). Per the note in ``_build_scheduler`` this is only
    # a placeholder for factory construction.
    _FALLBACK_MAX_STEPS = 100000

    # Rendezvous endpoint used only for single-process runs. torchrun exports
    # MASTER_ADDR / MASTER_PORT itself; a plain ``python train_fsdp2.py`` does
    # not, and the default env:// init of init_process_group fails without them.
    _SINGLE_PROCESS_MASTER_ADDR = "127.0.0.1"
    _SINGLE_PROCESS_MASTER_PORT = "29500"

    def __init__(self):
        # 1. Parse arguments
        self._parse_args()
        # 2. Initialize distributed environment
        self._initialize(seed=self.training_args.seed)
        self.rank = torch.distributed.get_rank()
        self._print_parsed_args()
        # 3. Build components
        self.model = self._build_model()
        self.tokenizer = self._build_tokenizer()
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
        self.data_manager = self._build_data_manager(self.tokenizer, self.template)
        self.optimizer = self._build_optimizer(self.model)
        self.lr_scheduler = self._build_scheduler(self.optimizer)
        self.checkpoint_manager = self._build_checkpointer()
        self.train_monitor = self._build_monitor()
        # 4. Dependency Injection
        self.trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            data_manager=self.data_manager,
            args=self.training_args,
            parallel_args=self.parallel_args,
            optimization_args=self.optimization_args,
            data_args=self.data_args,
            ckpt_manager=self.checkpoint_manager,
            monitor=self.train_monitor,
            tokenizer=self.tokenizer,
        )

    @staticmethod
    def _initialize(seed: int):
        """
        Prepare the NPU device, RNG state and torch.distributed process group.

        Args:
            seed: Base random seed forwarded to ``set_seed``.

        The device index comes from the ``LOCAL_RANK`` environment variable
        (injected by launchers such as torchrun) and defaults to 0.
        """
        set_accelerator_compatible(torch.npu)
        apply_hccl_premul_sum_patch()
        setup_global_logging(level="INFO")

        # --- 1. Select the NPU device for this process ---
        local_rank = MindSpeedAutoTrainer._resolve_local_rank()
        torch.accelerator.set_device_index(local_rank)
        torch.npu.set_device(local_rank)

        # --- 2. Seed RNGs ---
        # NOTE(review): the original note says MindSpeed's set_seed usually
        # offsets the seed per rank; it runs before the process group exists — confirm.
        set_seed(seed, set_deterministic=True)

        # --- 3. Initialize the distributed process group ---
        rank, world_size = MindSpeedAutoTrainer._ensure_dist_env()
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="hccl",
                rank=rank,
                world_size=world_size,
            )

    @staticmethod
    def _resolve_local_rank() -> int:
        """
        Return the local device index from ``LOCAL_RANK``.

        If the variable is missing (or ``-1``), 0 is used and written back to
        ``os.environ["LOCAL_RANK"]`` so later consumers see a consistent value.
        """
        local_rank = int(os.environ.get("LOCAL_RANK", -1))
        if local_rank == -1:
            # Fallback for single-node single-card or incorrect configuration.
            local_rank = 0
            os.environ["LOCAL_RANK"] = str(local_rank)
        return local_rank

    @staticmethod
    def _ensure_dist_env():
        """
        Fill in env vars required by env:// rendezvous; return ``(rank, world_size)``.

        When ``RANK`` is absent (non-torchrun launch) the run is treated as a
        single process: ``RANK=0`` and ``WORLD_SIZE=1``. Whenever the world size
        is 1, ``MASTER_ADDR``/``MASTER_PORT`` default to a localhost endpoint
        (existing values are kept).

        Raises:
            KeyError: if ``RANK`` is set but ``WORLD_SIZE`` is not.
            ValueError: if either variable is not an integer string.
        """
        if "RANK" not in os.environ:
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        if world_size == 1:
            os.environ.setdefault("MASTER_ADDR", MindSpeedAutoTrainer._SINGLE_PROCESS_MASTER_ADDR)
            os.environ.setdefault("MASTER_PORT", MindSpeedAutoTrainer._SINGLE_PROCESS_MASTER_PORT)
        return rank, world_size

    def train(self):
        """Run training, resuming from ``training_args.resume_from_checkpoint`` if provided."""
        self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)

    def _parse_args(self):
        """
        Parse all argument groups and publish a flat view of them via ``set_args``.

        The flat ``self.args`` namespace merges every group's attributes; on a
        name clash the group listed later (model < data < parallel < training
        < optimization) wins.
        """
        root_args = fsdp2_parse_args(Arguments)
        self.model_args = root_args.model
        self.data_args = root_args.data
        self.parallel_args = root_args.parallel
        self.training_args = root_args.training
        self.optimization_args = root_args.optimization

        merged = {}
        for group in (
            root_args.model,
            root_args.data,
            root_args.parallel,
            root_args.training,
            root_args.optimization,
        ):
            merged.update(vars(group))
        self.args = types.SimpleNamespace(**merged)
        set_args(self.args)

    def _print_parsed_args(self):
        """Log every parsed argument group (rank 0 only) as a name/value table."""
        arg_modules = [
            ("ModelArguments", self.model_args),
            ("DataArguments", self.data_args),
            ("ParallelArguments", self.parallel_args),
            ("TrainingArguments", self.training_args),
            ("OptimizationArguments", self.optimization_args),
        ]
        for module_name, arg_instance in arg_modules:
            logger.info_plain_rank0(f"\n {module_name}")
            logger.info_plain_rank0("-" * 60)
            for fld in fields(arg_instance):
                val = getattr(arg_instance, fld.name)
                # f-string formatting renders None as 'None' already.
                logger.info_plain_rank0(f" {fld.name:<30} {val}")

    # =========================================================================
    # Component Builders
    # =========================================================================
    def _build_tokenizer(self):
        """Create the tokenizer from the model arguments."""
        logger.info_rank0("> Building Tokenizer...")
        return TokenizerFactory.create(self.model_args)

    def _build_model(self):
        """Create the FSDP2-wrapped model from model and parallel arguments."""
        logger.info_rank0("> Building FSDP2 Model...")
        return ModelFactory.create(self.model_args, self.parallel_args)

    def _build_optimizer(self, model):
        """Create the optimizer for ``model`` from the training hyper-parameters."""
        logger.info_rank0("> Building Optimizer...")
        return OptimizerFactory.create(
            model=model,
            ep_size=self.parallel_args.ep_size,
            lr=self.training_args.lr,
            optimizer_type=self.training_args.optimizer,
            weight_decay=self.training_args.weight_decay,
            betas=(self.training_args.adam_beta1, self.training_args.adam_beta2),
            adam_epsilon=self.training_args.adam_epsilon,
        )

    def _build_scheduler(self, optimizer):
        """
        Create the LR scheduler for ``optimizer``.

        Uses ``training_args.max_steps`` when positive; otherwise (epoch mode)
        falls back to ``_FALLBACK_MAX_STEPS``. The original note states the
        Trainer computes total steps more accurately on its own.
        """
        logger.info_rank0("> Building LR Scheduler...")
        if self.training_args.max_steps > 0:
            max_steps = self.training_args.max_steps
        else:
            max_steps = self._FALLBACK_MAX_STEPS
        return SchedulerFactory.create(
            optimizer=optimizer,
            train_steps=max_steps,
            lr=self.training_args.lr,
            lr_decay_style=self.training_args.lr_scheduler_type,
            lr_warmup_ratio=self.training_args.warmup_ratio,
            lr_min=self.training_args.min_lr,
        )

    def _build_data_manager(self, tokenizer, template):
        """Create the data manager (stage is fixed to ``"sft"`` in this script)."""
        logger.info_rank0("> Building DataFactory...")
        return DataFactory.create(
            data_manager_type=self.data_args.data_manager_type,
            model_args=self.model_args,
            data_args=self.data_args,
            parallel_args=self.parallel_args,
            training_args=self.training_args,
            stage="sft",
            tokenizer=tokenizer,
            template=template,
        )

    def _build_monitor(self):
        """Create the training monitor from the HF config of the model path."""
        logger.info_rank0("> Building Monitor...")
        hf_config = AutoConfig.from_pretrained(
            self.model_args.model_name_or_path,
            trust_remote_code=True,
        )
        return TrainMonitor(self.training_args, hf_config)

    def _build_checkpointer(self):
        """Return the checkpoint manager handed to the Trainer."""
        logger.info_rank0("> Building Checkpointer...")
        # NOTE(review): this returns the CheckpointManager class itself, not an
        # instance — presumably the Trainer uses its class/static API; confirm.
        return CheckpointManager
# ==============================================================================
# [Facade] Unified AutoTrainer
# Single public entry point used by main(); it wraps the backend trainer.
# ==============================================================================
class AutoTrainer:
    """
    Unified entry point for training.

    This facade currently always constructs a ``MindSpeedAutoTrainer`` (the
    MindSpeed FSDP2 backend). No other backend is wired up here, and the
    ``TRAINING_BACKEND`` environment variable is not consulted.
    """

    def __init__(self):
        logger.info_rank0(">>> [AutoTrainer] Initializing MindSpeed FSDP backend...")
        self.trainer = MindSpeedAutoTrainer()

    def train(self):
        """Delegate to the wrapped backend trainer's ``train()``."""
        self.trainer.train()
@auto_coverage
def main():
    """Script entry: build the unified ``AutoTrainer`` and run training."""
    AutoTrainer().train()


# ==============================================================================
# [Entry Point]
# ==============================================================================
if __name__ == "__main__":
    # Only run training when executed as a script, not on import.
    main()