diff --git a/flux_train_network.py b/flux_train_network.py index 26503df1f..def441559 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -381,8 +381,7 @@ def get_noise_pred_and_target( t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward + # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( @@ -395,44 +394,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - - with torch.set_grad_enabled(is_train and train_unet): - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - return model_pred model_pred = call_dit( @@ -551,6 +512,11 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/library/train_util.py b/library/train_util.py index b23290663..1f591c422 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,17 +13,7 @@ import shutil import time import typing -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union -) +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math @@ -146,12 +136,13 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" + def split_train_val( - paths: List[str], + paths: List[str], sizes: List[Optional[Tuple[int, int]]], - is_training_dataset: bool, - validation_split: float, - validation_seed: int | None + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None, ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -1842,7 +1833,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" - # The is_training_dataset defines the type of dataset, training or validation + # The is_training_dataset defines the type of dataset, training or validation # if is_training_dataset is True -> training dataset # if is_training_dataset is False -> validation dataset def __init__( @@ -1981,29 +1972,25 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") # We want to create a training and validation split. This should be improved in the future - # to allow a clearer distinction between training and validation. This can be seen as a + # to allow a clearer distinction between training and validation. This can be seen as a # short-term solution to limit what is necessary to implement validation datasets - # + # # We split the dataset for the subset based on if we are doing a validation split - # The self.is_training_dataset defines the type of dataset, training or validation + # The self.is_training_dataset defines the type of dataset, training or validation # if self.is_training_dataset is True -> training dataset # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - # For regularization images we do not want to split this dataset. + # For regularization images we do not want to split this dataset. if subset.is_reg is True: # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] sizes = [] - # Otherwise the img_paths remain as original img_paths and no split + # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: img_paths, sizes = split_train_val( - img_paths, - sizes, - self.is_training_dataset, - self.validation_split, - self.validation_seed + img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") @@ -2373,7 +2360,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2431,9 +2418,9 @@ def __init__( self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + if min_timestep < max_timestep: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + else: + timesteps = torch.full((b_size,), max_timestep, device="cpu") timesteps = timesteps.long().to(device) return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: +def get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents: torch.FloatTensor +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -6441,7 +6433,7 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): """ Initialize experiment trackers with tracker specific behaviors """ @@ -6458,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr ) if "wandb" in [tracker.name for tracker in accelerator.trackers]: - import wandb + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) # Define specific metrics to handle validation and epochs "steps" wandb_tracker.define_metric("epoch", hidden=True) wandb_tracker.define_metric("val_step", hidden=True) + wandb_tracker.define_metric("global_step", hidden=True) + + # endregion diff --git a/sd3_train_network.py b/sd3_train_network.py index 9438bc7bc..cdb7aa4e3 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -450,14 +450,19 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # drop cached text encoder outputs + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True): + # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/train_network.py b/train_network.py index a8a23fd06..2d279b3bf 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import time import json from multiprocessing import Value +import numpy as np import toml from tqdm import tqdm @@ -100,9 +101,7 @@ def generate_step_logs( if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. - logs["lr/d*lr"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] - ) + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] else: idx = 0 if not args.network_train_unet_only: @@ -115,16 +114,56 @@ def generate_step_logs( logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) - if ( - args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): - logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] - ) + if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, global_step, global_step, epoch) + + def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, epoch, global_step, epoch) + + def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int): + self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step) + + def accelerator_logging( + self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None + ): + """ + step_value is for tensorboard, other values are for wandb + """ + tensorboard_tracker = None + wandb_tracker = None + other_trackers = [] + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + tensorboard_tracker = accelerator.get_tracker("tensorboard") + elif tracker.name == "wandb": + wandb_tracker = accelerator.get_tracker("wandb") + else: + other_trackers.append(accelerator.get_tracker(tracker.name)) + + if tensorboard_tracker is not None: + tensorboard_tracker.log(logs, step=step_value) + + if wandb_tracker is not None: + logs["global_step"] = global_step + logs["epoch"] = epoch + if val_step is not None: + logs["val_step"] = val_step + wandb_tracker.log(logs) + + for tracker in other_trackers: + tracker.log(logs, step=step_value) + + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): train_dataset_group.verify_bucket_reso_steps(64) if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) @@ -219,7 +258,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -309,28 +348,31 @@ def prepare_unet_with_accelerator( ) -> torch.nn.Module: return accelerator.prepare(unet) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True): + pass + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass # endregion def process_batch( - self, - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy: strategy_base.TextEncodingStrategy, - tokenize_strategy: strategy_base.TokenizeStrategy, - is_train=True, - train_text_encoder=True, - train_unet=True + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True, ) -> torch.Tensor: """ Process a batch for the network @@ -397,7 +439,7 @@ def process_batch( network, weight_dtype, train_unet, - is_train=is_train + is_train=is_train, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) @@ -484,7 +526,7 @@ def train(self, args): else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None # placeholder until validation dataset supported for arbitrary + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -701,7 +743,7 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - + val_dataloader = torch.utils.data.DataLoader( val_dataset_group if val_dataset_group is not None else [], shuffle=False, @@ -900,7 +942,9 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") + accelerator.print( + f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}" + ) accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -968,11 +1012,11 @@ def load_model_hook(models, input_dir): "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1243,12 +1287,6 @@ def remove_model(old_ckpt_name): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) - if args.max_validation_steps is not None - else len(val_dataloader) - ) - # training loop if initial_step > 0: # only if skip_until_initial_step is specified for skip_epoch in range(epoch_to_start): # skip epochs @@ -1271,13 +1309,53 @@ def remove_model(old_ckpt_name): range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" ) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + ) + NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep + validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] + validation_total_steps = validation_steps * len(validation_timesteps) + original_args_min_timestep = args.min_timestep + original_args_max_timestep = args.max_timestep + + def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + cpu_rng_state = torch.get_rng_state() + if accelerator.device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state() + elif accelerator.device.type == "xpu": + gpu_rng_state = torch.xpu.get_rng_state() + elif accelerator.device.type == "mps": + gpu_rng_state = torch.cuda.get_rng_state() + else: + gpu_rng_state = None + python_rng_state = random.getstate() + + torch.manual_seed(seed) + random.seed(seed) + + return (cpu_rng_state, gpu_rng_state, python_rng_state) + + def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + cpu_rng_state, gpu_rng_state, python_rng_state = rng_states + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + if accelerator.device.type == "cuda": + torch.cuda.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "xpu": + torch.xpu.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "mps": + torch.cuda.set_rng_state(gpu_rng_state) + random.setstate(python_rng_state) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) - accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here # TRAINING skipped_dataloader = None @@ -1294,25 +1372,25 @@ def remove_model(old_ckpt_name): with accelerator.accumulate(training_model): on_step_start_for_network(text_encoder, unet) - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + # preprocess batch for each model + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=True, - train_text_encoder=train_text_encoder, - train_unet=train_unet + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) accelerator.backward(loss) @@ -1369,148 +1447,167 @@ def remove_model(old_ckpt_name): if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if is_tracking: logs = self.generate_step_logs( - args, - current_loss, - avr_loss, - lr_scheduler, - lr_descriptions, - optimizer, - keys_scaled, - mean_norm, - maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch + 1) - # VALIDATION PER STEP - should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step - and global_step % args.validate_every_n_steps == 0 - ) + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... + should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + optimizer_eval_fn() + accelerator.unwrap_model(network).eval() + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) + val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="validation steps" + range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps", ) + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=False, - train_unet=False - ) - - current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) - - if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + for timestep in validation_timesteps: + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) + + args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, + ) + + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} + ) + + # if is_tracking: + # logs = {f"loss/validation/step_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) + + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + val_timesteps_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/step_average": val_step_loss_recorder.moving_average, - "loss/validation/step_divergence": loss_validation_divergence, + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, } - accelerator.log(logs, step=global_step) - + self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) + + restore_rng_state(rng_states) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep + optimizer_train_fn() + accelerator.unwrap_model(network).train() + progress_bar.unpause() + if global_step >= args.max_train_steps: break # EPOCH VALIDATION should_validate_epoch = ( - (epoch + 1) % args.validate_every_n_epochs == 0 - if args.validate_every_n_epochs is not None - else True + (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True ) if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + accelerator.unwrap_model(network).eval() + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) + val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="epoch validation steps" + range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps", ) + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + for timestep in validation_timesteps: + args.min_timestep = args.max_timestep = timestep - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=False, - train_unet=False - ) + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - current_loss = loss.detach().item() - val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + ) - if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step - } - accelerator.log(logs, step=global_step) + current_loss = loss.detach().item() + val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} + ) + + # if is_tracking: + # logs = {f"loss/validation/epoch_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) + + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + val_timesteps_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/epoch_average": avr_loss, - "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1 + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, } - accelerator.log(logs, step=global_step) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) + + restore_rng_state(rng_states) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep + optimizer_train_fn() + accelerator.unwrap_model(network).train() + progress_bar.unpause() # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} - accelerator.log(logs, step=global_step) - + logs = {"loss/epoch_average": loss_recorder.moving_average} + self.epoch_logging(accelerator, logs, global_step, epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1696,31 +1793,31 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する", ) parser.add_argument( "--validation_split", type=float, default=0.0, - help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合", ) parser.add_argument( "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます", ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます", ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" + help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", ) return parser