5252from diffusers .utils .torch_utils import is_compiled_module
5353
5454
55+ if is_wandb_available ():
56+ import wandb
57+
5558# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
5659check_min_version ("0.29.0.dev0" )
5760
@@ -99,6 +102,48 @@ def save_model_card(
99102 model_card .save (os .path .join (repo_folder , "README.md" ))
100103
101104
105+ def log_validation (
106+ pipeline ,
107+ args ,
108+ accelerator ,
109+ epoch ,
110+ is_final_validation = False ,
111+ ):
112+ logger .info (
113+ f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
114+ f" { args .validation_prompt } ."
115+ )
116+ pipeline = pipeline .to (accelerator .device )
117+ pipeline .set_progress_bar_config (disable = True )
118+ generator = torch .Generator (device = accelerator .device )
119+ if args .seed is not None :
120+ generator = generator .manual_seed (args .seed )
121+ images = []
122+ if torch .backends .mps .is_available ():
123+ autocast_ctx = nullcontext ()
124+ else :
125+ autocast_ctx = torch .autocast (accelerator .device .type )
126+
127+ with autocast_ctx :
128+ for _ in range (args .num_validation_images ):
129+ images .append (pipeline (args .validation_prompt , num_inference_steps = 30 , generator = generator ).images [0 ])
130+
131+ for tracker in accelerator .trackers :
132+ phase_name = "test" if is_final_validation else "validation"
133+ if tracker .name == "tensorboard" :
134+ np_images = np .stack ([np .asarray (img ) for img in images ])
135+ tracker .writer .add_images (phase_name , np_images , epoch , dataformats = "NHWC" )
136+ if tracker .name == "wandb" :
137+ tracker .log (
138+ {
139+ phase_name : [
140+ wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " ) for i , image in enumerate (images )
141+ ]
142+ }
143+ )
144+ return images
145+
146+
102147def parse_args ():
103148 parser = argparse .ArgumentParser (description = "Simple example of a training script." )
104149 parser .add_argument (
@@ -414,11 +459,6 @@ def main():
414459 if torch .backends .mps .is_available ():
415460 accelerator .native_amp = False
416461
417- if args .report_to == "wandb" :
418- if not is_wandb_available ():
419- raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
420- import wandb
421-
422462 # Make one log on every process with the configuration for debugging.
423463 logging .basicConfig (
424464 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -864,10 +904,6 @@ def collate_fn(examples):
864904
865905 if accelerator .is_main_process :
866906 if args .validation_prompt is not None and epoch % args .validation_epochs == 0 :
867- logger .info (
868- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
869- f" { args .validation_prompt } ."
870- )
871907 # create pipeline
872908 pipeline = DiffusionPipeline .from_pretrained (
873909 args .pretrained_model_name_or_path ,
@@ -876,38 +912,7 @@ def collate_fn(examples):
876912 variant = args .variant ,
877913 torch_dtype = weight_dtype ,
878914 )
879- pipeline = pipeline .to (accelerator .device )
880- pipeline .set_progress_bar_config (disable = True )
881-
882- # run inference
883- generator = torch .Generator (device = accelerator .device )
884- if args .seed is not None :
885- generator = generator .manual_seed (args .seed )
886- images = []
887- if torch .backends .mps .is_available ():
888- autocast_ctx = nullcontext ()
889- else :
890- autocast_ctx = torch .autocast (accelerator .device .type )
891-
892- with autocast_ctx :
893- for _ in range (args .num_validation_images ):
894- images .append (
895- pipeline (args .validation_prompt , num_inference_steps = 30 , generator = generator ).images [0 ]
896- )
897-
898- for tracker in accelerator .trackers :
899- if tracker .name == "tensorboard" :
900- np_images = np .stack ([np .asarray (img ) for img in images ])
901- tracker .writer .add_images ("validation" , np_images , epoch , dataformats = "NHWC" )
902- if tracker .name == "wandb" :
903- tracker .log (
904- {
905- "validation" : [
906- wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " )
907- for i , image in enumerate (images )
908- ]
909- }
910- )
915+ images = log_validation (pipeline , args , accelerator , epoch )
911916
912917 del pipeline
913918 torch .cuda .empty_cache ()
@@ -925,21 +930,6 @@ def collate_fn(examples):
925930 safe_serialization = True ,
926931 )
927932
928- if args .push_to_hub :
929- save_model_card (
930- repo_id ,
931- images = images ,
932- base_model = args .pretrained_model_name_or_path ,
933- dataset_name = args .dataset_name ,
934- repo_folder = args .output_dir ,
935- )
936- upload_folder (
937- repo_id = repo_id ,
938- folder_path = args .output_dir ,
939- commit_message = "End of training" ,
940- ignore_patterns = ["step_*" , "epoch_*" ],
941- )
942-
943933 # Final inference
944934 # Load previous pipeline
945935 if args .validation_prompt is not None :
@@ -949,41 +939,27 @@ def collate_fn(examples):
949939 variant = args .variant ,
950940 torch_dtype = weight_dtype ,
951941 )
952- pipeline = pipeline .to (accelerator .device )
953942
954943 # load attention processors
955944 pipeline .load_lora_weights (args .output_dir )
956945
957946 # run inference
958- generator = torch .Generator (device = accelerator .device )
959- if args .seed is not None :
960- generator = generator .manual_seed (args .seed )
961- images = []
962- if torch .backends .mps .is_available ():
963- autocast_ctx = nullcontext ()
964- else :
965- autocast_ctx = torch .autocast (accelerator .device .type )
966-
967- with autocast_ctx :
968- for _ in range (args .num_validation_images ):
969- images .append (
970- pipeline (args .validation_prompt , num_inference_steps = 30 , generator = generator ).images [0 ]
971- )
947+ images = log_validation (pipeline , args , accelerator , epoch , is_final_validation = True )
972948
973- for tracker in accelerator . trackers :
974- if len ( images ) != 0 :
975- if tracker . name == "tensorboard" :
976- np_images = np . stack ([ np . asarray ( img ) for img in images ])
977- tracker . writer . add_images ( "test" , np_images , epoch , dataformats = "NHWC" )
978- if tracker . name == "wandb" :
979- tracker . log (
980- {
981- "test" : [
982- wandb . Image ( image , caption = f" { i } : { args . validation_prompt } " )
983- for i , image in enumerate ( images )
984- ]
985- }
986- )
949+ if args . push_to_hub :
950+ save_model_card (
951+ repo_id ,
952+ images = images ,
953+ base_model = args . pretrained_model_name_or_path ,
954+ dataset_name = args . dataset_name ,
955+ repo_folder = args . output_dir ,
956+ )
957+ upload_folder (
958+ repo_id = repo_id ,
959+ folder_path = args . output_dir ,
960+ commit_message = "End of training" ,
961+ ignore_patterns = [ "step_*" , "epoch_*" ],
962+ )
987963
988964 accelerator .end_training ()
989965
0 commit comments