@@ -13,30 +13,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and

+import accelerate
 import argparse
 import copy
 import functools
 import logging
 import math
+import numpy as np
 import os
 import random
 import shutil
-from contextlib import nullcontext
-from pathlib import Path
-
-import accelerate
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import transformers
+from PIL import Image
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from contextlib import nullcontext
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from PIL import Image
+from pathlib import Path
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import (
@@ -60,7 +59,6 @@
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module

-
 if is_wandb_available():
     import wandb

@@ -73,7 +71,7 @@


 def log_validation(
-    vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
+        vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
 ):
     logger.info("Running validation... ")

@@ -266,7 +264,7 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
-        " If not specified controlnet weights are initialized from unet.",
+             " If not specified controlnet weights are initialized from unet.",
     )
     parser.add_argument(
         "--variant",
@@ -668,11 +666,11 @@ def parse_args(input_args=None):
         raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")

     if (
-        args.validation_image is not None
-        and args.validation_prompt is not None
-        and len(args.validation_image) != 1
-        and len(args.validation_prompt) != 1
-        and len(args.validation_image) != len(args.validation_prompt)
+            args.validation_image is not None
+            and args.validation_prompt is not None
+            and len(args.validation_image) != 1
+            and len(args.validation_prompt) != 1
+            and len(args.validation_image) != len(args.validation_prompt)
     ):
         raise ValueError(
             "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
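Note on the condition re-indented above: it permits broadcasting, i.e. a single `--validation_prompt` may be paired with several `--validation_image` entries (or vice versa); only mismatched multi-element lists are rejected. A minimal standalone sketch of the same rule, with a hypothetical helper name:

    def check_validation_args(images, prompts):
        # Reject only when both lists are given, neither has length 1,
        # and their lengths differ -- the condition mirrored from the hunk above.
        if (
            images is not None
            and prompts is not None
            and len(images) != 1
            and len(prompts) != 1
            and len(images) != len(prompts)
        ):
            raise ValueError("validation images/prompts must broadcast or match in length")

    check_validation_args(["a.png", "b.png"], ["one shared prompt"])  # OK: the prompt broadcasts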
@@ -695,10 +693,12 @@ def get_train_dataset(args, accelerator):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            trust_remote_code=args.trust_remote_code
         )
     if args.jsonl_for_train is not None:
         # load from json
-        dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
+        dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir,
+                               trust_remote_code=args.trust_remote_code)
         dataset = dataset.flatten_indices()
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
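The substantive change in this hunk threads `trust_remote_code` through both `load_dataset` calls, so Hub datasets that ship custom loading scripts only execute that code when the user opts in. A minimal sketch of the plumbing, assuming a matching `--trust_remote_code` flag is defined in `parse_args` (not shown in this diff), with a hypothetical dataset name:

    import argparse
    from datasets import load_dataset

    parser = argparse.ArgumentParser()
    # Assumed flag; off by default so remote code never runs silently.
    parser.add_argument("--trust_remote_code", action="store_true")
    args = parser.parse_args(["--trust_remote_code"])

    # "some/hub-dataset" is for illustration only.
    dataset = load_dataset("some/hub-dataset", trust_remote_code=args.trust_remote_code)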
@@ -1018,7 +1018,7 @@ def load_model_hook(models, input_dir):

     if args.scale_lr:
         args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+                args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
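For reference, `--scale_lr` applies linear scaling: the base rate is multiplied by the effective batch size across gradient accumulation and all processes. A quick worked example with assumed values:

    # With learning_rate=1e-4, gradient_accumulation_steps=2,
    # train_batch_size=4 and 8 processes:
    print(1e-4 * 2 * 4 * 8)  # 0.0064 -- the scaled learning rate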
@@ -1130,7 +1130,7 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
         len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
         num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
         num_training_steps_for_scheduler = (
-            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+                args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
         )
     else:
         num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
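The step count here compensates for data sharding and gradient accumulation: each process sees roughly `len(train_dataloader) / num_processes` batches per epoch, only every `gradient_accumulation_steps`-th batch yields an optimizer (and scheduler) step, and the result is multiplied back by `num_processes` to match how an accelerate-prepared scheduler is stepped on every process. A small worked example with assumed numbers:

    import math

    num_processes = 8
    len_train_dataloader = 1000              # assumed total batches per epoch
    gradient_accumulation_steps = 2
    num_train_epochs = 3

    per_process = math.ceil(len_train_dataloader / num_processes)                 # 125
    updates_per_epoch = math.ceil(per_process / gradient_accumulation_steps)      # 63
    steps_for_scheduler = num_train_epochs * updates_per_epoch * num_processes
    print(steps_for_scheduler)  # 1512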