# train.py (forked from kyutai-labs/moshi-finetune)
import dataclasses
import logging
import os
import pprint
import shutil
from contextlib import ExitStack
from pathlib import Path

import fire
import torch.cuda
import torch.distributed as dist
from torch.optim import AdamW, lr_scheduler

# from torch.profiler import ProfilerActivity, profile
from finetune.args import TrainArgs
from finetune.checkpointing import Checkpointer
from finetune.data.data_loader import build_data_loader
from finetune.data.interleaver import InterleavedTokenizer, Interleaver
from finetune.distributed import (
    BACKEND,
    avg_aggregate,
    get_rank,
    get_world_size,
    is_torchrun,
    set_device,
)
from finetune.eval import evaluate
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.monitoring.metrics_logger import (
    MetricsLogger,
    eval_log_msg,
    get_eval_logs,
    get_train_logs,
    train_log_msg,
)
from finetune.monitoring.utils import set_logger
from finetune.utils import TrainState, logged_closing, set_random_seed
from finetune.wrapped_model import get_fsdp_model
from moshi.models import loaders

logger = logging.getLogger("train")


def main_logger_info(message: str) -> None:
    if get_rank() == 0:
        logger.info(message)


def train(config: str):
    args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False)
    set_logger(logging.INFO)

    with ExitStack() as exit_stack:
        _train(args, exit_stack)

    logger.info("Closed everything!")
def _train(args: TrainArgs, exit_stack: ExitStack):
# 1. Initial setup and checks
set_random_seed(args.seed)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Init NCCL
if "LOCAL_RANK" in os.environ:
set_device()
logger.info("Going to init comms...")
dist.init_process_group(backend=BACKEND)
else:
logger.error(
"PyTorch environment is not correctly initialized. This message should only be displayed when testing."
)

    # 2. Init run dir
    main_logger_info(f"Run dir: {args.run_dir}")
    run_dir = Path(args.run_dir)

    if is_torchrun():
        if run_dir.exists() and not args.overwrite_run_dir:
            raise RuntimeError(
                f"Run dir {run_dir} already exists. Make sure to either rename `run_dir` or remove {run_dir}."
            )
        elif run_dir.exists():
            main_logger_info(f"Removing run dir {run_dir}...")
            shutil.rmtree(run_dir)

    if args.full_finetuning:
        assert not args.lora.enable, "LoRA should not be enabled for full finetuning."
    else:
        assert args.lora.enable, "LoRA should be enabled for partial finetuning."

    dist.barrier()
    run_dir.mkdir(exist_ok=True, parents=True)

    args_path = run_dir / "args.yaml"
    if not args_path.exists():
        args.save(args_path)

    main_logger_info(f"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}")

    # 3. Get loggers
    metrics_logger: MetricsLogger = MetricsLogger(
        run_dir,
        tag="train",
        is_master=get_rank() == 0,
        wandb_args=args.wandb,
        config=dataclasses.asdict(args),
    )
    exit_stack.enter_context(logged_closing(metrics_logger, "metrics_logger"))

    eval_logger: MetricsLogger = MetricsLogger(
        run_dir,
        tag="eval",
        is_master=get_rank() == 0,
        wandb_args=args.wandb,
        config=dataclasses.asdict(args),
    )
    exit_stack.enter_context(logged_closing(eval_logger, "eval_logger"))

    # 4.1 Load the Mimi audio encoder and the text tokenizer
    main_logger_info("Loading Mimi and Moshi...")
    checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
        hf_repo=args.moshi_paths.hf_repo_id,
        moshi_weights=args.moshi_paths.moshi_path,
        mimi_weights=args.moshi_paths.mimi_path,
        tokenizer=args.moshi_paths.tokenizer_path,
        config_path=args.moshi_paths.config_path,
    )
    lm_config = (
        loaders._lm_kwargs
        if checkpoint_info.raw_config is None
        else checkpoint_info.raw_config
    )
    lm_config["lora"] = args.lora.enable
    lm_config["lora_rank"] = args.lora.rank
    lm_config["lora_scaling"] = args.lora.scaling
    mimi = checkpoint_info.get_mimi(device="cuda")
    mimi.eval()
    for p in mimi.parameters():
        p.requires_grad = False

    # 4.2 Load and shard model, prepare interleaver for audio/text tokens.
    model = get_fsdp_model(args, checkpoint_info)

    spm = checkpoint_info.get_text_tokenizer()
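
    # The interleaver aligns text tokens with Mimi's audio frames (at
    # mimi.frame_rate frames per second) into the single interleaved token
    # stream Moshi is trained on, using the padding ids above where no text
    # is spoken.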
    interleaver = Interleaver(
        spm,
        mimi.frame_rate,
        model.text_padding_token_id,
        model.end_of_text_padding_id,
        model.zero_token_id,
        keep_main_only=True,
    )
    interleaved_tokenizer = InterleavedTokenizer(
        mimi, interleaver, duration_sec=args.duration_sec
    )

    # 5. Load data loaders
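    # Each DDP rank reads its own shard of the dataset (rank/world_size below);
    # the eval loader uses seed=None, presumably so the eval pass stays
    # identical from one evaluation to the next.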
    data_loader = build_data_loader(
        instruct_tokenizer=interleaved_tokenizer,
        args=args.data,
        batch_size=args.batch_size,
        seed=args.seed,
        rank=get_rank(),  # DDP rank
        world_size=get_world_size(),  # DDP world_size
        is_eval=False,
    )

    if args.do_eval:
        eval_data_loader = build_data_loader(
            instruct_tokenizer=interleaved_tokenizer,
            args=args.data,
            batch_size=args.batch_size,
            seed=None,
            rank=get_rank(),  # DDP rank
            world_size=get_world_size(),  # DDP world_size
            is_eval=True,
        )

    # 6. Define mixed precision
    param_dtype = getattr(torch, args.param_dtype)
    optim_dtype = torch.float32
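    # Params live in `param_dtype` (e.g. bfloat16) for forward/backward, while
    # optimizer updates run in float32; see the upcast/downcast calls around
    # optimizer.step() below.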

    assert args.lora is not None, "`args.lora` should be set to a valid value."

    # 7. Load optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=args.optim.lr,
        betas=(0.9, 0.95),
        eps=1e-08,
        weight_decay=args.optim.weight_decay,
    )
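
    # OneCycleLR ramps the learning rate up to max_lr over the first
    # `pct_start` fraction of the run, then anneals it back down.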
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.optim.lr,
        total_steps=args.max_steps,
        pct_start=args.optim.pct_start,
    )

    state = TrainState(args.max_steps)

    # 8. Initialize checkpointer
    if args.do_ckpt:
        checkpointer = Checkpointer(
            model=model,
            state=state,
            config=lm_config,
            run_dir=run_dir,
            optimizer=optimizer,
            num_ckpt_keep=args.num_ckpt_keep,
            full_finetuning=args.full_finetuning,
        )

    # 9. Prepare mixed precision
    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )
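    # (together with the upcast/downcast calls in the train loop, this keeps a
    # float32 master copy of the trainable params for optimizer updates)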

    # 10. train!
    model.train()
    torch.cuda.empty_cache()

    while state.step < args.max_steps:
        state.start_step()
        is_last_step = state.step == args.max_steps

        optimizer.zero_grad()

        loss = torch.tensor([0.0], device="cuda")
        n_batch_tokens: int = 0
        n_real_tokens: int = 0
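
        # Gradient accumulation: run `num_microbatches` forward/backward
        # passes before a single optimizer step, summing gradients as we go.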
        for i in range(args.num_microbatches):
            batch = next(data_loader)
            codes = batch.codes

            condition_tensors = None
            if batch.condition_attributes is not None:
                condition_tensors = model.condition_provider.prepare(
                    batch.condition_attributes
                )

            # forward / backward
            output = model(codes=codes, condition_tensors=condition_tensors)
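
            # `codes` stacks all token streams along dim 1: the text stream(s)
            # come first (up to model.audio_offset), followed by model.dep_q
            # audio codebook streams; each part gets its own masked loss.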
            text_loss = compute_loss_with_mask(
                output.text_logits,
                codes[:, : model.audio_offset],
                output.text_mask,
                mode="text",
                text_padding_weight=args.text_padding_weight,
                text_padding_ids={
                    model.text_padding_token_id,
                    model.end_of_text_padding_id,
                },
            )
            audio_loss = compute_loss_with_mask(
                output.logits,
                codes[:, model.audio_offset : model.audio_offset + model.dep_q],
                output.mask,
                mode="audio",
                first_codebook_weight_multiplier=args.first_codebook_weight_multiplier,
            )

            mb_loss = text_loss + audio_loss
            mb_loss.backward()

            loss += mb_loss.detach()
            n_batch_tokens += output.text_mask.numel() + output.mask.numel()
            n_real_tokens += (
                torch.sum(output.text_mask).item() + torch.sum(output.mask).item()
            )

            if i < args.num_microbatches - 1:
                # synchronize CUDA to re-run backward
                assert args.num_microbatches > 1  # should not happen
                torch.cuda.synchronize()
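
        # Average the accumulated gradients (and the logged loss) so the
        # update matches one batch of batch_size * num_microbatches samples.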
        if args.num_microbatches > 1:
            loss /= args.num_microbatches

            for p in model.parameters():
                if p.requires_grad:
                    assert p.grad is not None
                    p.grad.div_(args.num_microbatches)

        # upcast params for optimizer update
        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)

        # clip grad norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        # optimizer step
        optimizer.step()

        # downcast params for forward & backward
        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

        last_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # Host sync
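        # (loss.item() blocks until the GPU is done; avg_aggregate then
        # averages the scalar loss across all DDP ranks for logging)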
        loss_item = loss.item()
        avg_loss = avg_aggregate(loss_item)

        if args.do_eval and (
            (args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step
        ):
            # write perplexity to state
            evaluate(model, eval_data_loader, state, args)

            eval_logs = get_eval_logs(
                state.step,
                avg_loss,
                state.this_eval_perplexity,
                state.this_eval_loss,
            )
            main_logger_info(eval_log_msg(eval_logs))
            eval_logger.log(eval_logs, step=state.step)

        # Timing
        state.end_step(n_batch_tokens)

        if state.step % args.log_freq == 0:
            train_logs = get_train_logs(
                state,
                avg_loss,
                n_real_tokens,
                last_lr,
                torch.cuda.max_memory_allocated(),
                torch.cuda.memory_allocated(),
                args,
            )
            main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss))
            metrics_logger.log(train_logs, step=state.step)

        if args.do_ckpt and (
            (args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step
        ):
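            # For LoRA runs, save_only_lora stores just the adapter weights
            # (when args.save_adapters is set); full finetuning always saves
            # the full model.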
            checkpointer.save_checkpoint(
                save_only_lora=not args.full_finetuning and args.save_adapters,
                dtype=param_dtype,
            )

    main_logger_info("done!")


if __name__ == "__main__":
    """See README.md for usage."""
    fire.Fire(train)