
Commit 6c00cf0

sayakpaul and a-r-r-o-w authored
fix: resuming from a checkpoint when using deepspeed. (#38)
* fix: resuming from a checkpoint when using deepspeed.
* remove changes to prepare_dataset.py
* propagate to others.
* tackle gradnorm.

Co-authored-by: Aryan <[email protected]>
1 parent f0d9908 commit 6c00cf0
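For context, all three training scripts hand these save/load hooks to Accelerate's checkpointing machinery, which is what this commit touches. A minimal sketch of that wiring, assuming the usual diffusers-style registration these scripts follow (hook bodies elided; `accelerator` is the script's `Accelerator` instance):

    from accelerate import Accelerator

    accelerator = Accelerator()

    def save_model_hook(models, weights, output_dir):
        ...  # runs inside accelerator.save_state(output_dir)

    def load_model_hook(models, input_dir):
        ...  # runs inside accelerator.load_state(input_dir)

    # Accelerate invokes these hooks whenever training state is saved or
    # restored, which is how checkpoint resume reaches the code changed below.
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)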

4 files changed: +88 −53 lines changed


Makefile

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ check_dirs := training tests
 
 quality:
 	ruff check $(check_dirs)
-	ruff format --check $(check_dirs) setup.py
+	ruff format --check $(check_dirs)
 
 style:
 	ruff check $(check_dirs) --fix
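With this change, `setup.py` is dropped from the format check. For reference, the targets are invoked from the repository root (assuming ruff is installed; `check_dirs` covers the training and tests directories):

    make quality   # ruff check + ruff format --check
    make style     # ruff check --fix (the rest of the target is outside this hunk)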

training/cogvideox_image_to_video_lora.py

Lines changed: 27 additions & 16 deletions
@@ -26,6 +26,7 @@
 import diffusers
 import torch
 import transformers
+import wandb
 from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import (
@@ -52,8 +53,6 @@
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel
 
-import wandb
-
 
 from args import get_args  # isort:skip
 from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop  # isort:skip
@@ -385,13 +384,15 @@ def save_model_hook(models, weights, output_dir):
         transformer_lora_layers_to_save = None
 
         for model in models:
-            if isinstance(model, type(unwrap_model(transformer))):
+            if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                model = unwrap_model(model)
                 transformer_lora_layers_to_save = get_peft_model_state_dict(model)
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
             # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+            if weights:
+                weights.pop()
 
         CogVideoXImageToVideoPipeline.save_lora_weights(
             output_dir,
@@ -401,13 +402,20 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         transformer_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        # This is a bit of a hack but I don't know any other solution.
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"Unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}")
+        else:
+            transformer_ = CogVideoXTransformer3DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
 
         lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
 
@@ -795,12 +803,15 @@ def load_model_hook(models, input_dir):
                         logger.info(f"Saved state to {save_path}")
 
             last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
-            logs = {
-                "loss": loss.detach().item(),
-                "lr": last_lr,
-                "gradient_norm_before_clip": gradient_norm_before_clip,
-                "gradient_norm_after_clip": gradient_norm_after_clip,
-            }
+            logs = {"loss": loss.detach().item(), "lr": last_lr}
+            # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555
+            if accelerator.distributed_type != DistributedType.DEEPSPEED:
+                logs.update(
+                    {
+                        "gradient_norm_before_clip": gradient_norm_before_clip,
+                        "gradient_norm_after_clip": gradient_norm_after_clip,
+                    }
+                )
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
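Two things in this file's hooks are worth spelling out. First, under DeepSpeed the objects Accelerate hands to the hooks are engine-wrapped, so the old `isinstance(model, type(unwrap_model(transformer)))` check failed; unwrapping `model` on both sides of the check fixes saving. Second, the gradient norms are dropped from the logs under DeepSpeed because the engine clips gradients internally, so the loop-side norms are not meaningful there (see the linked issue). If a norm is still wanted, a hedged sketch of an alternative, not something this script does: DeepSpeed's engine exposes `get_global_grad_norm()`, and the assumption here is that after `accelerator.prepare()` the prepared model is that engine.

    from accelerate import DistributedType

    def deepspeed_grad_norm(accelerator, prepared_model):
        # Sketch only: the script simply omits the grad-norm log keys under
        # DeepSpeed. get_global_grad_norm() may return None on steps where the
        # engine has not computed a norm yet
        # (https://github.com/microsoft/DeepSpeed/issues/4555).
        if accelerator.distributed_type != DistributedType.DEEPSPEED:
            return None  # outside DeepSpeed the loop already tracks both norms
        norm = prepared_model.get_global_grad_norm()
        return float(norm) if norm is not None else None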

training/cogvideox_text_to_video_lora.py

Lines changed: 29 additions & 18 deletions
@@ -25,6 +25,7 @@
 import diffusers
 import torch
 import transformers
+import wandb
 from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import (
@@ -51,8 +52,6 @@
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel
 
-import wandb
-
 
 from args import get_args  # isort:skip
 from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop  # isort:skip
@@ -358,13 +357,15 @@ def save_model_hook(models, weights, output_dir):
         transformer_lora_layers_to_save = None
 
         for model in models:
-            if isinstance(model, type(unwrap_model(transformer))):
+            if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                model = unwrap_model(model)
                 transformer_lora_layers_to_save = get_peft_model_state_dict(model)
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
             # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+            if weights:
+                weights.pop()
 
         CogVideoXPipeline.save_lora_weights(
             output_dir,
@@ -374,13 +375,20 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         transformer_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        # This is a bit of a hack but I don't know any other solution.
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"Unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}")
+        else:
+            transformer_ = CogVideoXTransformer3DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
 
         lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
 
@@ -553,7 +561,7 @@ def collate_fn(data):
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
-    if accelerator.is_main_process:
+    if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
        tracker_name = args.tracker_name or "cogvideox-lora"
        accelerator.init_trackers(tracker_name, config=vars(args))
 
@@ -731,7 +739,7 @@ def collate_fn(data):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -758,12 +766,15 @@ def collate_fn(data):
                         logger.info(f"Saved state to {save_path}")
 
             last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
-            logs = {
-                "loss": loss.detach().item(),
-                "lr": last_lr,
-                "gradient_norm_before_clip": gradient_norm_before_clip,
-                "gradient_norm_after_clip": gradient_norm_after_clip,
-            }
+            logs = {"loss": loss.detach().item(), "lr": last_lr}
+            # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555
+            if accelerator.distributed_type != DistributedType.DEEPSPEED:
+                logs.update(
+                    {
+                        "gradient_norm_before_clip": gradient_norm_before_clip,
+                        "gradient_norm_after_clip": gradient_norm_after_clip,
+                    }
+                )
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
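The reordered guards make DeepSpeed runs take the tracker-init and checkpointing branches on every process rather than only the main one; with ZeRO, each rank holds a shard of the optimizer state, so `save_state`/`load_state` have to be invoked collectively. The save/resume round trip those branches serve looks roughly like this, a sketch that reuses the script's own names (`args`, `global_step`, `accelerator`) and assumes the usual diffusers-style resume handling around it:

    import os

    # Saving, inside the training loop at checkpointing_steps intervals:
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    accelerator.save_state(save_path)   # fires save_model_hook on this rank

    # Resuming, before the loop:
    accelerator.load_state(save_path)   # fires load_model_hook on this rank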

training/cogvideox_text_to_video_sft.py

Lines changed: 31 additions & 18 deletions
@@ -25,7 +25,8 @@
 import diffusers
 import torch
 import transformers
-from accelerate import Accelerator, DistributedType
+import wandb
+from accelerate import Accelerator, DistributedType, init_empty_weights
 from accelerate.logging import get_logger
 from accelerate.utils import (
     DistributedDataParallelKwargs,
@@ -50,8 +51,6 @@
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel
 
-import wandb
-
 
 from args import get_args  # isort:skip
 from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop  # isort:skip
@@ -336,31 +335,42 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     model: CogVideoXTransformer3DModel
+                    model = unwrap_model(model)
                     model.save_pretrained(
                         os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
                     )
                 else:
                     raise ValueError(f"Unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
     def load_model_hook(models, input_dir):
         transformer_ = None
+        init_under_meta = False
 
-        while len(models) > 0:
-            model = models.pop()
+        # This is a bit of a hack but I don't know any other solution.
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_: CogVideoXTransformer3DModel = model
-            else:
-                raise ValueError(f"Unexpected save model: {model.__class__.__name__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}")
+        else:
+            with init_empty_weights():
+                transformer_ = CogVideoXTransformer3DModel.from_config(
+                    args.pretrained_model_name_or_path, subfolder="transformer"
+                )
+            init_under_meta = True
 
         load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
         transformer_.register_to_config(**load_model.config)
-        transformer_.load_state_dict(load_model.state_dict())
+        transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
         del load_model
 
         # Make sure the trainable params are in float32. This is again needed since the base models
@@ -722,12 +732,15 @@ def collate_fn(data):
                         logger.info(f"Saved state to {save_path}")
 
             last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
-            logs = {
-                "loss": loss.detach().item(),
-                "lr": last_lr,
-                "gradient_norm_before_clip": gradient_norm_before_clip,
-                "gradient_norm_after_clip": gradient_norm_after_clip,
-            }
+            logs = {"loss": loss.detach().item(), "lr": last_lr}
+            # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555
+            if accelerator.distributed_type != DistributedType.DEEPSPEED:
+                logs.update(
+                    {
+                        "gradient_norm_before_clip": gradient_norm_before_clip,
+                        "gradient_norm_after_clip": gradient_norm_after_clip,
+                    }
+                )
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
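The SFT load path rebuilds the transformer under `init_empty_weights()` when DeepSpeed is active, which creates parameters on the meta device without allocating storage; `load_state_dict(..., assign=init_under_meta)` then adopts the checkpoint tensors outright instead of copying into storage that does not exist. A toy, self-contained illustration of that pattern (requires PyTorch >= 2.1 for `assign=`; the Linear layer stands in for the real model):

    import torch
    from accelerate import init_empty_weights

    # Parameters created in this context live on the "meta" device and hold no data.
    with init_empty_weights():
        model = torch.nn.Linear(4, 4)
    assert model.weight.device.type == "meta"

    # A real checkpoint would come from disk; a freshly built layer stands in here.
    state = torch.nn.Linear(4, 4).state_dict()

    # assign=True swaps the meta tensors for the loaded ones instead of copying.
    model.load_state_dict(state, assign=True)
    assert model.weight.device.type == "cpu"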
