
Commit 0a78cf1

a-r-r-o-w committed
update
Co-Authored-By: yuan-shenghai <[email protected]>
1 parent 935d460 commit 0a78cf1

File tree

2 files changed: +5 -5 lines changed


examples/cogvideo/train_cogvideox_image_to_video_lora.py

Lines changed: 2 additions & 2 deletions
@@ -1461,7 +1461,7 @@ def collate_fn(examples):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -1494,7 +1494,7 @@ def collate_fn(examples):
             if global_step >= args.max_train_steps:
                 break
 
-        if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
+        if accelerator.is_main_process:
             if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
                 # Create pipeline
                 pipe = CogVideoXImageToVideoPipeline.from_pretrained(
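Why the checkpointing guard changed: DeepSpeed shards model and optimizer state across ranks, so accelerator.save_state has to be entered by every process rather than only the main one, while validation stays main-process-only. A minimal sketch of that pattern, under that assumption (the helper name and checkpoint path layout are illustrative, not taken from this commit):

    import os

    from accelerate import Accelerator, DistributedType


    def maybe_save_checkpoint(accelerator: Accelerator, output_dir: str,
                              global_step: int, checkpointing_steps: int) -> None:
        # Hypothetical helper: under DeepSpeed every rank must call save_state(),
        # because the model/optimizer state it writes is sharded across processes.
        if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
            if global_step % checkpointing_steps == 0:
                save_path = os.path.join(output_dir, f"checkpoint-{global_step}")  # illustrative path
                accelerator.save_state(save_path)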

examples/cogvideo/train_cogvideox_lora.py

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,7 @@
 import torch
 import torchvision.transforms as TT
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
@@ -1211,7 +1211,7 @@ def load_model_hook(models, input_dir):
     )
     use_deepspeed_scheduler = (
         accelerator.state.deepspeed_plugin is not None
-        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
     )
 
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
@@ -1456,7 +1456,7 @@ def collate_fn(examples):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
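The use_deepspeed_scheduler fix flips the condition so the flag is true when the DeepSpeed config actually defines a "scheduler" section; in that case Accelerate's DeepSpeed integration expects a DummyScheduler placeholder and lets DeepSpeed build the real LR scheduler from its own config. A hedged sketch of how such a flag is typically consumed (the fallback scheduler name and argument names below are illustrative, not taken from this diff):

    from accelerate import Accelerator
    from accelerate.utils import DummyScheduler
    from diffusers.optimization import get_scheduler


    def build_lr_scheduler(accelerator: Accelerator, optimizer,
                           num_warmup_steps: int, max_train_steps: int):
        # True only when the DeepSpeed config contains a "scheduler" entry -- the corrected condition.
        use_deepspeed_scheduler = (
            accelerator.state.deepspeed_plugin is not None
            and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
        )
        if use_deepspeed_scheduler:
            # Placeholder object; DeepSpeed instantiates the actual scheduler from its config.
            return DummyScheduler(optimizer, total_num_steps=max_train_steps,
                                  warmup_num_steps=num_warmup_steps)
        # Otherwise fall back to a regular diffusers scheduler (name chosen for illustration).
        return get_scheduler(
            "cosine_with_restarts",
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=max_train_steps,
        )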
