Skip to content

Commit 7f654e3

Browse files
authored
[misc] Polish V1 training code (#469)
1 parent 8631c1b commit 7f654e3

File tree

13 files changed

+46
-60
lines changed

13 files changed

+46
-60
lines changed

fastvideo/distill.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,7 @@ def main(args):
242242
noise_random_generator = None
243243

244244
# Handle the repository creation
245-
if rank <= 0 and args.output_dir is not None:
245+
if rank == 0 and args.output_dir is not None:
246246
os.makedirs(args.output_dir, exist_ok=True)
247247

248248
# For mixed precision training we cast all non-trainable weights to half-precision
@@ -391,7 +391,7 @@ def main(args):
391391
len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
392392
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
393393

394-
if rank <= 0:
394+
if rank == 0:
395395
project = args.tracker_project_name or "fastvideo"
396396
wandb.init(project=project, config=args)
397397

@@ -493,7 +493,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
493493
"phases": num_phases,
494494
})
495495
progress_bar.update(1)
496-
if rank <= 0:
496+
if rank == 0:
497497
wandb.log(
498498
{
499499
"train_loss": loss,

fastvideo/distill_adv.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -296,7 +296,7 @@ def main(args):
296296
noise_random_generator = None
297297

298298
# Handle the repository creation
299-
if rank <= 0 and args.output_dir is not None:
299+
if rank == 0 and args.output_dir is not None:
300300
os.makedirs(args.output_dir, exist_ok=True)
301301

302302
# For mixed precision training we cast all non-trainable weights to half-precision
@@ -462,7 +462,7 @@ def main(args):
462462
len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
463463
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
464464

465-
if rank <= 0:
465+
if rank == 0:
466466
project = args.tracker_project_name or "fastvideo"
467467
wandb.init(project=project, config=args)
468468

@@ -559,7 +559,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
559559
"step_time": f"{step_time:.2f}s",
560560
})
561561
progress_bar.update(1)
562-
if rank <= 0:
562+
if rank == 0:
563563
wandb.log(
564564
{
565565
"generator_loss": generator_loss,

fastvideo/sample/sample_t2v_hunyuan_hf.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ def inference(args):
8686
num_inference_steps=args.num_inference_steps,
8787
generator=generator,
8888
).frames
89-
if nccl_info.global_rank <= 0:
89+
if nccl_info.global_rank == 0:
9090
os.makedirs(args.output_path, exist_ok=True)
9191
suffix = prompt.split(".")[0]
9292
export_to_video(
@@ -107,7 +107,7 @@ def inference(args):
107107
generator=generator,
108108
).frames
109109

110-
if nccl_info.global_rank <= 0:
110+
if nccl_info.global_rank == 0:
111111
export_to_video(videos[0], args.output_path + ".mp4", fps=24)
112112

113113

fastvideo/sample/sample_t2v_mochi.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def main(args):
9494
guidance_scale=args.guidance_scale,
9595
generator=generator,
9696
).frames
97-
if nccl_info.global_rank <= 0:
97+
if nccl_info.global_rank == 0:
9898
os.makedirs(args.output_path, exist_ok=True)
9999
suffix = prompt.split(".")[0]
100100
export_to_video(
@@ -116,7 +116,7 @@ def main(args):
116116
generator=generator,
117117
).frames
118118

119-
if nccl_info.global_rank <= 0:
119+
if nccl_info.global_rank == 0:
120120
export_to_video(videos[0], args.output_path + ".mp4", fps=30)
121121

122122

fastvideo/train.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def main(args):
185185
noise_random_generator = None
186186

187187
# Handle the repository creation
188-
if rank <= 0 and args.output_dir is not None:
188+
if rank == 0 and args.output_dir is not None:
189189
os.makedirs(args.output_dir, exist_ok=True)
190190

191191
# For mixed precision training we cast all non-trainable weights to half-precision
@@ -316,7 +316,7 @@ def main(args):
316316
len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
317317
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
318318

319-
if rank <= 0:
319+
if rank == 0:
320320
project = args.tracker_project_name or "fastvideo"
321321
wandb.init(project=project, config=args)
322322

@@ -393,7 +393,7 @@ def main(args):
393393
"grad_norm": grad_norm,
394394
})
395395
progress_bar.update(1)
396-
if rank <= 0:
396+
if rank == 0:
397397
wandb.log(
398398
{
399399
"train_loss": loss,

fastvideo/utils/checkpoint.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def save_checkpoint_optimizer(model, optimizer, rank, output_dir, step, discrimi
3232
save_dir = os.path.join(output_dir, f"checkpoint-{step}")
3333
os.makedirs(save_dir, exist_ok=True)
3434
# save using safetensors
35-
if rank <= 0 and not discriminator:
35+
if rank == 0 and not discriminator:
3636
weight_path = os.path.join(save_dir, "diffusion_pytorch_model.safetensors")
3737
save_file(cpu_state, weight_path)
3838
config_dict = dict(model.config)
@@ -60,7 +60,7 @@ def save_checkpoint(transformer, rank, output_dir, step):
6060
):
6161
cpu_state = transformer.state_dict()
6262
# todo move to get_state_dict
63-
if rank <= 0:
63+
if rank == 0:
6464
save_dir = os.path.join(output_dir, f"checkpoint-{step}")
6565
os.makedirs(save_dir, exist_ok=True)
6666
# save using safetensors
@@ -98,7 +98,7 @@ def save_checkpoint_generator_discriminator(
9898
hf_weight_dir = os.path.join(save_dir, "hf_weights")
9999
os.makedirs(hf_weight_dir, exist_ok=True)
100100
# save using safetensors
101-
if rank <= 0:
101+
if rank == 0:
102102
config_dict = dict(model.config)
103103
config_path = os.path.join(hf_weight_dir, "config.json")
104104
# save dict as json
@@ -139,7 +139,7 @@ def save_checkpoint_generator_discriminator(
139139
optim_state = FSDP.optim_state_dict(discriminator, discriminator_optimizer)
140140
model_state = discriminator.state_dict()
141141
state_dict = {"optimizer": optim_state, "model": model_state}
142-
if rank <= 0:
142+
if rank == 0:
143143
discriminator_fsdp_state_fil = os.path.join(discriminator_fsdp_state_dir, "discriminator_state.pt")
144144
torch.save(state_dict, discriminator_fsdp_state_fil)
145145

@@ -178,7 +178,7 @@ def load_full_state_model(model, optimizer, checkpoint_file, rank):
178178
):
179179
discriminator_state = torch.load(checkpoint_file)
180180
model_state = discriminator_state["model"]
181-
if rank <= 0:
181+
if rank == 0:
182182
optim_state = discriminator_state["optimizer"]
183183
else:
184184
optim_state = None
@@ -241,7 +241,7 @@ def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, pipelin
241241
optimizer,
242242
)
243243

244-
if rank <= 0:
244+
if rank == 0:
245245
save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}")
246246
os.makedirs(save_dir, exist_ok=True)
247247

fastvideo/v1/dataset/latent_datasets.py

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -107,23 +107,3 @@ def latent_collate_function(batch):
107107
prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0)
108108
latents = torch.stack(latent_list, dim=0)
109109
return latents, prompt_embeds, latent_attn_mask, prompt_attention_masks
110-
111-
112-
if __name__ == "__main__":
113-
dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt",
114-
num_latent_t=28,
115-
cfg_rate=0.0)
116-
dataloader = torch.utils.data.DataLoader(dataset,
117-
batch_size=2,
118-
shuffle=False,
119-
collate_fn=latent_collate_function)
120-
for latent, prompt_embed, latent_attn_mask, prompt_attention_mask in dataloader:
121-
print(
122-
latent.shape,
123-
prompt_embed.shape,
124-
latent_attn_mask.shape,
125-
prompt_attention_mask.shape,
126-
)
127-
import pdb
128-
129-
pdb.set_trace()

fastvideo/v1/distributed/parallel_state.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -747,8 +747,8 @@ def set_custom_all_reduce(enable: bool):
747747

748748

749749
def init_distributed_environment(
750-
world_size: int = -1,
751-
rank: int = -1,
750+
world_size: int = 1,
751+
rank: int = 0,
752752
distributed_init_method: str = "env://",
753753
local_rank: int = -1,
754754
backend: str = "nccl",

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ def maybe_init_distributed_environment(self, fastvideo_args: FastVideoArgs):
188188

189189
if local_rank == -1 or world_size == -1 or rank == -1:
190190
raise ValueError(
191-
"Local rank, world size, and rank must be set. Use torchrun to launch the script."
191+
"Local rank, world size, and rank must be set. Use torchrun to launch the script or pass rank to the worker process."
192192
)
193193

194194
torch.cuda.set_device(local_rank)

fastvideo/v1/training/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
from .training_pipeline import TrainingPipeline
2+
from .wan_training_pipeline import WanTrainingPipeline
3+
4+
__all__ = ["TrainingPipeline", "WanTrainingPipeline"]

0 commit comments

Comments
 (0)