From 532f5c58a6e83a3400f82103f5854ff3f63d77d7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 20:50:42 +0900 Subject: [PATCH 01/15] formatting --- train_network.py | 229 ++++++++++++++++++++++------------------------- 1 file changed, 108 insertions(+), 121 deletions(-) diff --git a/train_network.py b/train_network.py index 2c3bb2aae..cc54be7cc 100644 --- a/train_network.py +++ b/train_network.py @@ -100,9 +100,7 @@ def generate_step_logs( if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. - logs["lr/d*lr"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] - ) + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] else: idx = 0 if not args.network_train_unet_only: @@ -115,16 +113,17 @@ def generate_step_logs( logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) - if ( - args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): - logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] - ) + if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): train_dataset_group.verify_bucket_reso_steps(64) if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) @@ -219,7 +218,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, 
- is_train=True + is_train=True, ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -315,22 +314,22 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion def process_batch( - self, - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy: strategy_base.TextEncodingStrategy, - tokenize_strategy: strategy_base.TokenizeStrategy, - is_train=True, - train_text_encoder=True, - train_unet=True + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True, ) -> torch.Tensor: """ Process a batch for the network @@ -397,7 +396,7 @@ def process_batch( network, weight_dtype, train_unet, - is_train=is_train + is_train=is_train, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) @@ -484,7 +483,7 @@ def train(self, args): else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None # placeholder until validation dataset supported for arbitrary + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -701,7 +700,7 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - + val_dataloader = torch.utils.data.DataLoader( val_dataset_group if val_dataset_group is not None else [], shuffle=False, @@ -900,7 +899,9 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: 
{train_dataset_group.num_train_images}") - accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") + accelerator.print( + f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}" + ) accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -968,11 +969,11 @@ def load_model_hook(models, input_dir): "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1248,9 +1249,7 @@ def remove_model(old_ckpt_name): accelerator.log({}, step=0) validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) - if args.max_validation_steps is not None - else len(val_dataloader) + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) ) # training loop @@ -1298,21 +1297,21 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - 
weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=True, - train_text_encoder=train_text_encoder, - train_unet=train_unet + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) accelerator.backward(loss) @@ -1369,32 +1368,21 @@ def remove_model(old_ckpt_name): if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if is_tracking: logs = self.generate_step_logs( - args, - current_loss, - avr_loss, - lr_scheduler, - lr_descriptions, - optimizer, - keys_scaled, - mean_norm, - maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=global_step) # VALIDATION PER STEP should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step + args.validate_every_n_steps is not None + and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="validation steps" + range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: @@ -1404,27 +1392,27 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + 
noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) if is_tracking: logs = { @@ -1436,26 +1424,25 @@ def remove_model(old_ckpt_name): if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/step_average": val_step_loss_recorder.moving_average, - "loss/validation/step_divergence": loss_validation_divergence, + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, } accelerator.log(logs, step=global_step) - + if global_step >= args.max_train_steps: break # EPOCH VALIDATION should_validate_epoch = ( - (epoch + 1) % args.validate_every_n_epochs == 0 - if args.validate_every_n_epochs is not None - else True + (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True ) if should_validate_epoch and len(val_dataloader) > 0: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="epoch validation steps" + range(validation_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps", ) for val_step, batch in enumerate(val_dataloader): @@ -1466,43 +1453,43 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - 
vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) if is_tracking: logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step, } accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss logs = { - "loss/validation/epoch_average": avr_loss, - "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1 + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + "epoch": epoch + 1, } accelerator.log(logs, step=global_step) @@ -1510,7 +1497,7 @@ def remove_model(old_ckpt_name): if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} accelerator.log(logs, step=global_step) - + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1696,31 +1683,31 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, 
training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する", ) parser.add_argument( "--validation_split", type=float, default=0.0, - help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合", ) parser.add_argument( "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます", ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます", ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" + help="Max number of validation dataset items processed. 
By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", ) return parser From 86a2f3fd262e52b3249d9f5508efe4774f1fa3ed Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:10:52 +0900 Subject: [PATCH 02/15] Fix gradient handling when Text Encoders are trained --- flux_train_network.py | 43 ++----------------------------------------- sd3_train_network.py | 2 +- train_network.py | 10 +++++----- 3 files changed, 8 insertions(+), 47 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 5cd1b9d51..475bd751b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -376,9 +376,8 @@ def get_noise_pred_and_target( t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode + with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -390,44 +389,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - 
txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - - with torch.set_grad_enabled(is_train and train_unet): - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - return model_pred model_pred = call_dit( diff --git a/sd3_train_network.py b/sd3_train_network.py index dcf497f53..2f4579492 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -345,7 +345,7 @@ def get_noise_pred_and_target( t5_attn_mask = None # call model - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index cc54be7cc..6f1652fd9 100644 --- a/train_network.py +++ b/train_network.py @@ -232,7 +232,7 @@ def get_noise_pred_and_target( t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -1405,8 +1405,8 @@ def remove_model(old_ckpt_name): text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, ) current_loss = loss.detach().item() @@ -1466,8 +1466,8 @@ def 
remove_model(old_ckpt_name): text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) current_loss = loss.detach().item() From b6a309321675b5d0a59b776ffb4d0ecdd3d28ec2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:22:11 +0900 Subject: [PATCH 03/15] call optimizer eval/train fn before/after validation --- train_network.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_network.py b/train_network.py index 6f1652fd9..e735c582d 100644 --- a/train_network.py +++ b/train_network.py @@ -1381,6 +1381,8 @@ def remove_model(old_ckpt_name): and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) @@ -1429,6 +1431,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + optimizer_train_fn() + if global_step >= args.max_train_steps: break @@ -1438,6 +1442,8 @@ def remove_model(old_ckpt_name): ) if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, @@ -1493,6 +1499,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + optimizer_train_fn() + # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} From 29f31d005f12a08650389164fa9c60504928d451 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:35:43 +0900 Subject: [PATCH 04/15] add network.train()/eval() for validation --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index e735c582d..9b8036f8b 100644 --- a/train_network.py +++ b/train_network.py @@ -1276,7 +1276,7 @@ 
def remove_model(old_ckpt_name): metadata["ss_epoch"] = str(epoch + 1) - accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here # TRAINING skipped_dataloader = None @@ -1382,6 +1382,7 @@ def remove_model(old_ckpt_name): ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" @@ -1432,6 +1433,7 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() if global_step >= args.max_train_steps: break @@ -1443,6 +1445,7 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), @@ -1500,6 +1503,7 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() # END OF EPOCH if is_tracking: From 0750859133eec7858052cd3f79106113fa786e94 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:56:59 +0900 Subject: [PATCH 05/15] validation: Implement timestep-based validation processing --- sd3_train_network.py | 1 + train_network.py | 167 +++++++++++++++++++++++++------------------ 2 files changed, 100 insertions(+), 68 deletions(-) diff --git a/sd3_train_network.py b/sd3_train_network.py index 2f4579492..d4f131252 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -446,6 +446,7 @@ def forward(hidden_states): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # TODO consider validation # drop cached text encoder outputs 
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: diff --git a/train_network.py b/train_network.py index 9b8036f8b..a63e9d1e9 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import time import json from multiprocessing import Value +import numpy as np import toml from tqdm import tqdm @@ -1248,10 +1249,6 @@ def remove_model(old_ckpt_name): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - ) - # training loop if initial_step > 0: # only if skip_until_initial_step is specified for skip_epoch in range(epoch_to_start): # skip epochs @@ -1270,6 +1267,17 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + ) + NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep + validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] + validation_total_steps = validation_steps * len(validation_timesteps) + original_args_min_timestep = args.min_timestep + original_args_max_timestep = args.max_timestep + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1385,44 +1393,55 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( - range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" + 
range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps", ) + val_ts_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True - train_unet=train_unet, - ) - - current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) - - if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + for timestep in validation_timesteps: + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, + ) + + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_avg_loss": 
val_step_loss_recorder.moving_average, "timestep": timestep} + ) + + if is_tracking: + logs = { + "loss/validation/step_current": current_loss, + "val_step": (epoch * validation_total_steps) + val_ts_step, + } + accelerator.log(logs, step=global_step) + + val_ts_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average @@ -1432,6 +1451,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() @@ -1448,49 +1469,57 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( - range(validation_steps), + range(validation_total_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="epoch validation steps", ) + val_ts_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + for timestep in validation_timesteps: + args.min_timestep = args.max_timestep = timestep - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=train_text_encoder, - train_unet=train_unet, - ) + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + ) - current_loss = loss.detach().item() - 
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) + current_loss = loss.detach().item() + val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} + ) - if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + if is_tracking: + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_total_steps) + val_ts_step, + } + accelerator.log(logs, step=global_step) + + val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1502,6 +1531,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() From 45ec02b2a8b5eb5af8f5b4877381dc4dcc596cb9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 22:10:38 +0900 Subject: [PATCH 06/15] use same noise for every validation --- flux_train_network.py | 1 - train_network.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index aab025735..475bd751b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -377,7 +377,6 @@ def get_noise_pred_and_target( def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode - with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 
for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( diff --git a/train_network.py b/train_network.py index a63e9d1e9..f0deb67ab 100644 --- a/train_network.py +++ b/train_network.py @@ -1391,6 +1391,8 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1451,6 +1453,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1467,6 +1470,8 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1531,6 +1536,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From c5b803ce94bd70812e6979ac7b986a769659b14e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 21:59:09 +0900 Subject: [PATCH 07/15] rng state management: Implement functions to get and set RNG states for consistent validation --- train_network.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index f0deb67ab..b3c7ff524 100644 --- a/train_network.py +++ b/train_network.py 
@@ -1278,6 +1278,31 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep + def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + cpu_rng_state = torch.get_rng_state() + if accelerator.device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state() + elif accelerator.device.type == "xpu": + gpu_rng_state = torch.xpu.get_rng_state() + elif accelerator.device.type == "mps": + gpu_rng_state = torch.cuda.get_rng_state() + else: + gpu_rng_state = None + python_rng_state = random.getstate() + return (cpu_rng_state, gpu_rng_state, python_rng_state) + + def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + cpu_rng_state, gpu_rng_state, python_rng_state = rng_states + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + if accelerator.device.type == "cuda": + torch.cuda.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "xpu": + torch.xpu.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "mps": + torch.cuda.set_rng_state(gpu_rng_state) + random.setstate(python_rng_state) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1391,7 +1416,7 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1453,7 +1478,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1470,7 +1495,7 @@ def 
remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1536,7 +1561,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From a24db1d532a95cc9dd91aba25a06b8eb58db5cff Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 22:02:42 +0900 Subject: [PATCH 08/15] fix: validation timestep generation fails on SD/SDXL training --- library/train_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a994..01fa64674 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5935,7 +5935,10 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + if min_timestep < max_timestep: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + else: + timesteps = torch.full((b_size,), max_timestep, device="cpu") timesteps = timesteps.long().to(device) return timesteps From 0911683717e439676bba758a5f7a29356984966c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 20:53:49 +0900 Subject: [PATCH 09/15] set python random state --- train_network.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index b3c7ff524..083e5993d 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,7 +1278,7 @@ def remove_model(old_ckpt_name): 
original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1289,9 +1289,13 @@ def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple else: gpu_rng_state = None python_rng_state = random.getstate() + + torch.manual_seed(seed) + random.seed(seed) + return (cpu_rng_state, gpu_rng_state, python_rng_state) - def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): cpu_rng_state, gpu_rng_state, python_rng_state = rng_states torch.set_rng_state(cpu_rng_state) if gpu_rng_state is not None: @@ -1416,8 +1420,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1478,7 +1481,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1495,8 +1498,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] if should_validate_epoch and len(val_dataloader) > 0: 
optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1561,7 +1563,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From 344845b42941b48956dce94d614fbf32e900c70e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 21:25:40 +0900 Subject: [PATCH 10/15] fix: validation with block swap --- flux_train_network.py | 14 ++++++++++++-- sd3_train_network.py | 19 ++++++++++++++----- train_network.py | 18 +++++++++++------- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 475bd751b..e97dfc5b8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,7 +36,12 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -341,7 +346,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -507,6 +512,11 @@ def 
forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/sd3_train_network.py b/sd3_train_network.py index d4f131252..216d93c58 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,7 +26,12 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -317,7 +322,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -445,15 +450,19 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # TODO consider validation - # drop cached text encoder outputs + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True): + # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed text_encoder_outputs_list = 
batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/train_network.py b/train_network.py index 083e5993d..49013c708 100644 --- a/train_network.py +++ b/train_network.py @@ -309,7 +309,10 @@ def prepare_unet_with_accelerator( ) -> torch.nn.Module: return accelerator.prepare(unet) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True): + pass + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass # endregion @@ -1278,7 +1281,7 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1330,8 +1333,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen with 
accelerator.accumulate(training_model): on_step_start_for_network(text_encoder, unet) - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + # preprocess batch for each model + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) loss = self.process_batch( batch, @@ -1434,8 +1437,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen break for timestep in validation_timesteps: - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep @@ -1471,6 +1473,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: @@ -1516,7 +1519,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.min_timestep = args.max_timestep = timestep # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) loss = self.process_batch( batch, @@ -1551,6 +1554,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: From 177203818a024329efa74640a588674323363373 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:46 +0900 Subject: [PATCH 11/15] fix: unpause training 
progress bar after validation --- train_network.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_network.py b/train_network.py index 49013c708..8bfb19258 100644 --- a/train_network.py +++ b/train_network.py @@ -1489,6 +1489,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() if global_step >= args.max_train_steps: break @@ -1572,6 +1573,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() # END OF EPOCH if is_tracking: From cd80752175c663ede2cb7995da652ed5f5f7f749 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:58 +0900 Subject: [PATCH 12/15] fix: remove unused parameter 'accelerator' from encode_images_to_latents method --- flux_train_network.py | 2 +- sd3_train_network.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index e97dfc5b8..def441559 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -328,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): diff --git a/sd3_train_network.py b/sd3_train_network.py index 216d93c58..cdb7aa4e3 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -304,7 +304,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000,
shift=args.training_shift) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): From 76b761943b5166f496aa1cb8ffbcc2d04469346a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:53:57 +0900 Subject: [PATCH 13/15] fix: simplify validation step condition in NetworkTrainer --- train_network.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/train_network.py b/train_network.py index 8bfb19258..99c58f49f 100644 --- a/train_network.py +++ b/train_network.py @@ -1414,12 +1414,9 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) accelerator.log(logs, step=global_step) - # VALIDATION PER STEP - should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step - and global_step % args.validate_every_n_steps == 0 - ) + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... 
+ should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() From 4a369961346ca153a370728247449978d8a33415 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 18 Feb 2025 22:05:08 +0900 Subject: [PATCH 14/15] modify log step calculation --- train_network.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/train_network.py b/train_network.py index 47c4bb56e..93558da45 100644 --- a/train_network.py +++ b/train_network.py @@ -1464,11 +1464,10 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/step_current": current_loss} + accelerator.log( + logs, step=global_step + val_ts_step + ) # a bit weird to log with global_step + val_ts_step self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 @@ -1545,25 +1544,20 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/epoch_current": current_loss} + accelerator.log(logs, step=global_step + val_ts_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average + loss_validation_divergence = 
val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1, } - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1574,8 +1568,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} - accelerator.log(logs, step=global_step) + logs = {"loss/epoch_average": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From efb2a128cd0d2c6340a21bf544e77853a20b3453 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 21 Feb 2025 22:07:35 +0900 Subject: [PATCH 15/15] fix wandb val logging --- library/train_util.py | 57 +++++++++++++++------------------ train_network.py | 73 ++++++++++++++++++++++++++++++++----------- 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 258701982..1f591c422 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,17 +13,7 @@ import shutil import time import typing -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union -) +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math @@ -146,12 +136,13 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" + def split_train_val( - paths: List[str], + paths: List[str], sizes: List[Optional[Tuple[int, int]]], - is_training_dataset: bool, - validation_split: float, - validation_seed: int | None + is_training_dataset: 
bool, + validation_split: float, + validation_seed: int | None, ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -1842,7 +1833,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" - # The is_training_dataset defines the type of dataset, training or validation + # The is_training_dataset defines the type of dataset, training or validation # if is_training_dataset is True -> training dataset # if is_training_dataset is False -> validation dataset def __init__( @@ -1981,29 +1972,25 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") # We want to create a training and validation split. This should be improved in the future - # to allow a clearer distinction between training and validation. This can be seen as a + # to allow a clearer distinction between training and validation. This can be seen as a # short-term solution to limit what is necessary to implement validation datasets - # + # # We split the dataset for the subset based on if we are doing a validation split - # The self.is_training_dataset defines the type of dataset, training or validation + # The self.is_training_dataset defines the type of dataset, training or validation # if self.is_training_dataset is True -> training dataset # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - # For regularization images we do not want to split this dataset. + # For regularization images we do not want to split this dataset. 
if subset.is_reg is True: # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] sizes = [] - # Otherwise the img_paths remain as original img_paths and no split + # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: img_paths, sizes = split_train_val( - img_paths, - sizes, - self.is_training_dataset, - self.validation_split, - self.validation_seed + img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") @@ -2373,7 +2360,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2431,9 +2418,9 @@ def __init__( self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -5952,7 +5939,9 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: +def get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents: torch.FloatTensor +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if 
args.noise_offset: @@ -6444,7 +6433,7 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): """ Initialize experiment trackers with tracker specific behaviors """ @@ -6461,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr ) if "wandb" in [tracker.name for tracker in accelerator.trackers]: - import wandb + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) # Define specific metrics to handle validation and epochs "steps" wandb_tracker.define_metric("epoch", hidden=True) wandb_tracker.define_metric("val_step", hidden=True) + wandb_tracker.define_metric("global_step", hidden=True) + + # endregion diff --git a/train_network.py b/train_network.py index 93558da45..ab5483deb 100644 --- a/train_network.py +++ b/train_network.py @@ -119,6 +119,45 @@ def generate_step_logs( return logs + def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, global_step, global_step, epoch) + + def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, epoch, global_step, epoch) + + def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int): + self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step) + + def accelerator_logging( + self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None + ): + """ + step_value is for tensorboard, other values are for wandb + """ + tensorboard_tracker = None + wandb_tracker = None + 
other_trackers = [] + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + tensorboard_tracker = accelerator.get_tracker("tensorboard") + elif tracker.name == "wandb": + wandb_tracker = accelerator.get_tracker("wandb") + else: + other_trackers.append(accelerator.get_tracker(tracker.name)) + + if tensorboard_tracker is not None: + tensorboard_tracker.log(logs, step=step_value) + + if wandb_tracker is not None: + logs["global_step"] = global_step + logs["epoch"] = epoch + if val_step is not None: + logs["val_step"] = val_step + wandb_tracker.log(logs) + + for tracker in other_trackers: + tracker.log(logs, step=step_value) + def assert_extra_args( self, args, @@ -1412,7 +1451,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... 
@@ -1428,7 +1467,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen disable=not accelerator.is_local_main_process, desc="validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1457,20 +1496,18 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/step_current": current_loss} - accelerator.log( - logs, step=global_step + val_ts_step - ) # a bit weird to log with global_step + val_ts_step + # if is_tracking: + # logs = {f"loss/validation/step_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average @@ -1478,7 +1515,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen "loss/validation/step_average": val_step_loss_recorder.moving_average, "loss/validation/step_divergence": loss_validation_divergence, } - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1507,7 +1544,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen desc="epoch validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in 
enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1537,18 +1574,18 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) current_loss = loss.detach().item() - val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/epoch_current": current_loss} - accelerator.log(logs, step=global_step + val_ts_step) + # if is_tracking: + # logs = {f"loss/validation/epoch_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1557,7 +1594,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, } - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1569,7 +1606,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) accelerator.wait_for_everyone()