diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 125368841fa8..b8c946fa29e9 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -842,7 +842,7 @@ def collate_fn(examples): resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar = tqdm(range(0, args.max_train_steps), initial=global_step, disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") for epoch in range(first_epoch, args.num_train_epochs): @@ -851,8 +851,6 @@ def collate_fn(examples): for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) continue with accelerator.accumulate(unet):