@@ -24,7 +24,6 @@
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
-from torch.utils.data.sampler import Sampler, BatchSampler
 
 import numpy as np
 import torch
@@ -40,6 +39,7 @@
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -57,9 +57,9 @@
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    find_nearest_bucket,
     free_memory,
     parse_buckets_string,
-    find_nearest_bucket
 )
 from diffusers.utils import (
     check_min_version,
@@ -70,6 +70,7 @@
 from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
+
 if is_wandb_available():
     import wandb
 
@@ -81,13 +82,14 @@
 if is_torch_npu_available():
     torch.npu.config.allow_internal_format = False
 
+
 def save_model_card(
-        repo_id: str,
-        images=None,
-        base_model: str = None,
-        instance_prompt=None,
-        validation_prompt=None,
-        repo_folder=None,
+    repo_id: str,
+    images=None,
+    base_model: str = None,
+    instance_prompt=None,
+    validation_prompt=None,
+    repo_folder=None,
 ):
     widget_dict = []
     if images is not None:
@@ -189,13 +191,13 @@ def load_text_encoders(class_one, class_two, class_three):
 
 
 def log_validation(
-        pipeline,
-        args,
-        accelerator,
-        pipeline_args,
-        epoch,
-        torch_dtype,
-        is_final_validation=False,
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    torch_dtype,
+    is_final_validation=False,
 ):
     args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
     logger.info(
@@ -244,7 +246,7 @@ def log_validation(
 
 
 def import_model_class_from_model_name_or_path(
-        pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
@@ -331,8 +333,8 @@ def parse_args(input_args=None):
         type=str,
         default="image",
         help="The column of the dataset containing the target image. By "
-             "default, the standard Image Dataset maps out 'file_name' "
-             "to 'image'.",
+        "default, the standard Image Dataset maps out 'file_name' "
+        "to 'image'.",
     )
     parser.add_argument(
         "--caption_column",
@@ -598,7 +600,7 @@ def parse_args(input_args=None):
         type=float,
         default=None,
         help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
-             "uses the value of square root of beta2. Ignored if optimizer is adamW",
+        "uses the value of square root of beta2. Ignored if optimizer is adamW",
     )
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
@@ -629,7 +631,7 @@ def parse_args(input_args=None):
         type=bool,
         default=True,
         help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
-             "Ignored if optimizer is adamW",
+        "Ignored if optimizer is adamW",
     )
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
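
Note on the Prodigy hunks above: these CLI flags are typically forwarded to the optimizer constructor elsewhere in the script. A minimal sketch of that wiring, assuming the `prodigyopt` package and a `params_to_optimize` list (neither appears in this diff), with the `safeguard_warmup` flag name inferred from the help string shown above:

```python
# Hedged sketch: how the Prodigy flags above are usually consumed.
# `params_to_optimize` stands in for the script's trainable parameters.
import prodigyopt

optimizer = prodigyopt.Prodigy(
    params_to_optimize,
    lr=args.learning_rate,
    beta3=args.prodigy_beta3,  # None falls back to sqrt(beta2), per the help text
    weight_decay=args.adam_weight_decay,
    decouple=args.prodigy_decouple,  # AdamW-style decoupled weight decay
    safeguard_warmup=args.prodigy_safeguard_warmup,  # flag name assumed here
)
```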
@@ -736,17 +738,17 @@ class DreamBoothDataset(Dataset):
     """
 
     def __init__(
-            self,
-            instance_data_root,
-            instance_prompt,
-            class_prompt,
-            class_data_root=None,
-            class_num=None,
-            size=1024,
-            repeats=1,
-            center_crop=False,
-            buckets=[(1024,1024),(768,1360),(1360, 768),(880, 1168),(1168, 880), (1248, 832), (832, 1248)],
-            # buckets=[(1024, 1024)],
+        self,
+        instance_data_root,
+        instance_prompt,
+        class_prompt,
+        class_data_root=None,
+        class_num=None,
+        size=1024,
+        repeats=1,
+        center_crop=False,
+        buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)],
+        # buckets=[(1024, 1024)],
     ):
         # self.size = (size, size)
         self.center_crop = center_crop
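
The `buckets` default added above pairs with `find_nearest_bucket` (reordered in the imports earlier in this diff) to route each training image to the resolution that distorts it least. A minimal sketch of that assignment, with the closest-aspect-ratio criterion assumed for illustration since the helper's body is not part of this diff:

```python
# Hedged sketch of nearest-bucket assignment; diffusers' find_nearest_bucket
# plays this role in the script, but the exact criterion here is an assumption.
def nearest_bucket_index(height, width, buckets):
    aspect = height / width
    return min(range(len(buckets)), key=lambda i: abs(buckets[i][0] / buckets[i][1] - aspect))

buckets = [(1024, 1024), (768, 1360), (1360, 768)]
print(nearest_bucket_index(900, 1500, buckets))  # -> 1, i.e. the (768, 1360) bucket
```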
@@ -930,11 +932,9 @@ def collate_fn(examples, with_prior_preservation=False):
 class BucketBatchSampler(BatchSampler):
     def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
         if not isinstance(batch_size, int) or batch_size <= 0:
-            raise ValueError("batch_size should be a positive integer value, "
-                             "but got batch_size={}".format(batch_size))
+            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
         if not isinstance(drop_last, bool):
-            raise ValueError("drop_last should be a boolean value, but got "
-                             "drop_last={}".format(drop_last))
+            raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
 
         self.dataset = dataset
         self.batch_size = batch_size
@@ -954,7 +954,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
             random.shuffle(indices_in_bucket)
             # Create batches
             for i in range(0, len(indices_in_bucket), self.batch_size):
-                batch = indices_in_bucket[i: i + self.batch_size]
+                batch = indices_in_bucket[i : i + self.batch_size]
                 if len(batch) < self.batch_size and self.drop_last:
                     continue  # Skip partial batch if drop_last is True
                 self.batches.append(batch)
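
To make the sampler's intent concrete: it groups sample indices by bucket, shuffles within each bucket, and chunks into batches so every batch shares one resolution. A self-contained sketch of that logic (the helper name and `bucket_ids` input are illustrative, not from the script):

```python
import random
from collections import defaultdict

# Hedged sketch of the batching logic above; bucket_ids[i] is the bucket
# assigned to sample i (a stand-in for the dataset's bucket bookkeeping).
def make_bucket_batches(bucket_ids, batch_size, drop_last=False):
    buckets = defaultdict(list)
    for idx, bucket in enumerate(bucket_ids):
        buckets[bucket].append(idx)
    batches = []
    for indices in buckets.values():
        random.shuffle(indices)  # shuffle within each bucket
        for i in range(0, len(indices), batch_size):
            batch = indices[i : i + batch_size]
            if len(batch) < batch_size and drop_last:
                continue  # skip partial batch if drop_last is True
            batches.append(batch)
    random.shuffle(batches)  # mix bucket order across the epoch
    return batches

# Two buckets, batch_size=2: indices 0, 1, 4 never mix with 2, 3.
print(make_bucket_batches([0, 0, 1, 1, 0], batch_size=2))
```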
@@ -1064,7 +1064,7 @@ def main(args):
             pipeline.to(accelerator.device)
 
             for example in tqdm(
-                    sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
                 images = pipeline(example["prompt"]).images
 
@@ -1278,7 +1278,7 @@ def load_model_hook(models, input_dir):
 
     if args.scale_lr:
         args.learning_rate = (
-                args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
     # Make sure the trainable params are in float32.
@@ -1368,10 +1368,7 @@ def load_model_hook(models, input_dir):
         repeats=args.repeats,
         center_crop=args.center_crop,
     )
-    batch_sampler = BucketBatchSampler(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        drop_last=False)
+    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
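
Worth noting for the hunk above: once `batch_sampler` is supplied, PyTorch's `DataLoader` treats `batch_size`, `shuffle`, `sampler`, and `drop_last` as mutually exclusive options and raises a `ValueError` if any of them is also set, which is why `drop_last` is handled inside `BucketBatchSampler` rather than on the loader.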