1313import shutil
1414import time
1515import typing
16- from typing import (
17- Any ,
18- Callable ,
19- Dict ,
20- List ,
21- NamedTuple ,
22- Optional ,
23- Sequence ,
24- Tuple ,
25- Union
26- )
16+ from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Tuple , Union
2717from accelerate import Accelerator , InitProcessGroupKwargs , DistributedDataParallelKwargs , PartialState
2818import glob
2919import math
146136TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
147137TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
148138
139+
149140def split_train_val (
150- paths : List [str ],
141+ paths : List [str ],
151142 sizes : List [Optional [Tuple [int , int ]]],
152- is_training_dataset : bool ,
153- validation_split : float ,
154- validation_seed : int | None
143+ is_training_dataset : bool ,
144+ validation_split : float ,
145+ validation_seed : int | None ,
155146) -> Tuple [List [str ], List [Optional [Tuple [int , int ]]]]:
156147 """
157148 Split the dataset into train and validation
@@ -1842,7 +1833,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
18421833class DreamBoothDataset (BaseDataset ):
18431834 IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
18441835
1845- # The is_training_dataset defines the type of dataset, training or validation
1836+ # The is_training_dataset defines the type of dataset, training or validation
18461837 # if is_training_dataset is True -> training dataset
18471838 # if is_training_dataset is False -> validation dataset
18481839 def __init__ (
@@ -1981,29 +1972,25 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
19811972 logger .info (f"set image size from cache files: { size_set_count } /{ len (img_paths )} " )
19821973
19831974 # We want to create a training and validation split. This should be improved in the future
1984- # to allow a clearer distinction between training and validation. This can be seen as a
1975+ # to allow a clearer distinction between training and validation. This can be seen as a
19851976 # short-term solution to limit what is necessary to implement validation datasets
1986- #
1977+ #
19871978 # We split the dataset for the subset based on if we are doing a validation split
1988- # The self.is_training_dataset defines the type of dataset, training or validation
1979+ # The self.is_training_dataset defines the type of dataset, training or validation
19891980 # if self.is_training_dataset is True -> training dataset
19901981 # if self.is_training_dataset is False -> validation dataset
19911982 if self .validation_split > 0.0 :
1992- # For regularization images we do not want to split this dataset.
1983+ # For regularization images we do not want to split this dataset.
19931984 if subset .is_reg is True :
19941985 # Skip any validation dataset for regularization images
19951986 if self .is_training_dataset is False :
19961987 img_paths = []
19971988 sizes = []
1998- # Otherwise the img_paths remain as original img_paths and no split
1989+ # Otherwise the img_paths remain as original img_paths and no split
19991990 # required for training images dataset of regularization images
20001991 else :
20011992 img_paths , sizes = split_train_val (
2002- img_paths ,
2003- sizes ,
2004- self .is_training_dataset ,
2005- self .validation_split ,
2006- self .validation_seed
1993+ img_paths , sizes , self .is_training_dataset , self .validation_split , self .validation_seed
20071994 )
20081995
20091996 logger .info (f"found directory { subset .image_dir } contains { len (img_paths )} image files" )
@@ -2373,7 +2360,7 @@ def __init__(
23732360 bucket_no_upscale : bool ,
23742361 debug_dataset : bool ,
23752362 validation_split : float ,
2376- validation_seed : Optional [int ],
2363+ validation_seed : Optional [int ],
23772364 ) -> None :
23782365 super ().__init__ (resolution , network_multiplier , debug_dataset )
23792366
@@ -2431,9 +2418,9 @@ def __init__(
24312418 self .image_data = self .dreambooth_dataset_delegate .image_data
24322419 self .batch_size = batch_size
24332420 self .num_train_images = self .dreambooth_dataset_delegate .num_train_images
2434- self .num_reg_images = self .dreambooth_dataset_delegate .num_reg_images
2421+ self .num_reg_images = self .dreambooth_dataset_delegate .num_reg_images
24352422 self .validation_split = validation_split
2436- self .validation_seed = validation_seed
2423+ self .validation_seed = validation_seed
24372424
24382425 # assert all conditioning data exists
24392426 missing_imgs = []
@@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common(
59445931
59455932
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
    """
    Return one integer timestep per batch element as a long tensor on `device`.

    When min_timestep < max_timestep, the values are drawn with torch.randint,
    i.e. from the half-open range [min_timestep, max_timestep). Otherwise the
    range is empty and every element is set to max_timestep.

    Args:
        min_timestep: inclusive lower bound of the sampling range.
        max_timestep: exclusive upper bound of the sampling range; also the
            constant fill value when the range is empty.
        b_size: number of timesteps to produce (length of the returned tensor).
        device: device the returned tensor is moved to.

    Returns:
        A 1-D tensor of shape (b_size,) with dtype torch.long on `device`.
    """
    if min_timestep < max_timestep:
        # The tensor is created on the CPU and moved to `device` below.
        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
    else:
        # torch.randint raises when low >= high, so an empty range falls back
        # to a constant tensor instead of crashing.
        timesteps = torch.full((b_size,), max_timestep, device="cpu")
    timesteps = timesteps.long().to(device)
    return timesteps
59505940
59515941
5952- def get_noise_noisy_latents_and_timesteps (args , noise_scheduler , latents : torch .FloatTensor ) -> Tuple [torch .FloatTensor , torch .FloatTensor , torch .IntTensor ]:
5942+ def get_noise_noisy_latents_and_timesteps (
5943+ args , noise_scheduler , latents : torch .FloatTensor
5944+ ) -> Tuple [torch .FloatTensor , torch .FloatTensor , torch .IntTensor ]:
59535945 # Sample noise that we'll add to the latents
59545946 noise = torch .randn_like (latents , device = latents .device )
59555947 if args .noise_offset :
@@ -6441,7 +6433,7 @@ def sample_image_inference(
64416433 wandb_tracker .log ({f"sample_{ i } " : wandb .Image (image , caption = prompt )}, commit = False ) # positive prompt as a caption
64426434
64436435
6444- def init_trackers (accelerator : Accelerator , args : argparse .Namespace , default_tracker_name : str ):
6436+ def init_trackers (accelerator : Accelerator , args : argparse .Namespace , default_tracker_name : str ):
64456437 """
64466438 Initialize experiment trackers with tracker specific behaviors
64476439 """
@@ -6458,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
64586450 )
64596451
64606452 if "wandb" in [tracker .name for tracker in accelerator .trackers ]:
6461- import wandb
6453+ import wandb
6454+
64626455 wandb_tracker = accelerator .get_tracker ("wandb" , unwrap = True )
64636456
64646457 # Define specific metrics to handle validation and epochs "steps"
64656458 wandb_tracker .define_metric ("epoch" , hidden = True )
64666459 wandb_tracker .define_metric ("val_step" , hidden = True )
64676460
6461+ wandb_tracker .define_metric ("global_step" , hidden = True )
6462+
6463+
64686464# endregion
64696465
64706466
0 commit comments