6464from diffusers .utils .torch_utils import is_compiled_module
6565
6666
67+ if is_wandb_available ():
68+ import wandb
69+
6770# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
6871check_min_version ("0.29.0.dev0" )
6972
@@ -119,6 +122,47 @@ def save_model_card(
119122 model_card .save (os .path .join (repo_folder , "README.md" ))
120123
121124
125+ def log_validation (
126+ pipeline ,
127+ args ,
128+ accelerator ,
129+ epoch ,
130+ is_final_validation = False ,
131+ ):
132+ logger .info (
133+ f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
134+ f" { args .validation_prompt } ."
135+ )
136+ pipeline = pipeline .to (accelerator .device )
137+ pipeline .set_progress_bar_config (disable = True )
138+
139+ # run inference
140+ generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if args .seed else None
141+ pipeline_args = {"prompt" : args .validation_prompt }
142+ if torch .backends .mps .is_available ():
143+ autocast_ctx = nullcontext ()
144+ else :
145+ autocast_ctx = torch .autocast (accelerator .device .type )
146+
147+ with autocast_ctx :
148+ images = [pipeline (** pipeline_args , generator = generator ).images [0 ] for _ in range (args .num_validation_images )]
149+
150+ for tracker in accelerator .trackers :
151+ phase_name = "test" if is_final_validation else "validation"
152+ if tracker .name == "tensorboard" :
153+ np_images = np .stack ([np .asarray (img ) for img in images ])
154+ tracker .writer .add_images (phase_name , np_images , epoch , dataformats = "NHWC" )
155+ if tracker .name == "wandb" :
156+ tracker .log (
157+ {
158+ phase_name : [
159+ wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " ) for i , image in enumerate (images )
160+ ]
161+ }
162+ )
163+ return images
164+
165+
122166def import_model_class_from_model_name_or_path (
123167 pretrained_model_name_or_path : str , revision : str , subfolder : str = "text_encoder"
124168):
@@ -523,11 +567,6 @@ def main(args):
523567 kwargs_handlers = [kwargs ],
524568 )
525569
526- if args .report_to == "wandb" :
527- if not is_wandb_available ():
528- raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
529- import wandb
530-
531570 # Make one log on every process with the configuration for debugging.
532571 logging .basicConfig (
533572 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -1196,10 +1235,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
11961235
11971236 if accelerator .is_main_process :
11981237 if args .validation_prompt is not None and epoch % args .validation_epochs == 0 :
1199- logger .info (
1200- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
1201- f" { args .validation_prompt } ."
1202- )
12031238 # create pipeline
12041239 pipeline = StableDiffusionXLPipeline .from_pretrained (
12051240 args .pretrained_model_name_or_path ,
@@ -1212,36 +1247,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
12121247 torch_dtype = weight_dtype ,
12131248 )
12141249
1215- pipeline = pipeline .to (accelerator .device )
1216- pipeline .set_progress_bar_config (disable = True )
1217-
1218- # run inference
1219- generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if args .seed else None
1220- pipeline_args = {"prompt" : args .validation_prompt }
1221- if torch .backends .mps .is_available ():
1222- autocast_ctx = nullcontext ()
1223- else :
1224- autocast_ctx = torch .autocast (accelerator .device .type )
1225-
1226- with autocast_ctx :
1227- images = [
1228- pipeline (** pipeline_args , generator = generator ).images [0 ]
1229- for _ in range (args .num_validation_images )
1230- ]
1231-
1232- for tracker in accelerator .trackers :
1233- if tracker .name == "tensorboard" :
1234- np_images = np .stack ([np .asarray (img ) for img in images ])
1235- tracker .writer .add_images ("validation" , np_images , epoch , dataformats = "NHWC" )
1236- if tracker .name == "wandb" :
1237- tracker .log (
1238- {
1239- "validation" : [
1240- wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " )
1241- for i , image in enumerate (images )
1242- ]
1243- }
1244- )
1250+ images = log_validation (pipeline , args , accelerator , epoch )
12451251
12461252 del pipeline
12471253 torch .cuda .empty_cache ()
@@ -1288,33 +1294,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
12881294 variant = args .variant ,
12891295 torch_dtype = weight_dtype ,
12901296 )
1291- pipeline = pipeline .to (accelerator .device )
12921297
12931298 # load attention processors
12941299 pipeline .load_lora_weights (args .output_dir )
12951300
12961301 # run inference
1297- images = []
12981302 if args .validation_prompt and args .num_validation_images > 0 :
1299- generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if args .seed else None
1300- images = [
1301- pipeline (args .validation_prompt , num_inference_steps = 25 , generator = generator ).images [0 ]
1302- for _ in range (args .num_validation_images )
1303- ]
1304-
1305- for tracker in accelerator .trackers :
1306- if tracker .name == "tensorboard" :
1307- np_images = np .stack ([np .asarray (img ) for img in images ])
1308- tracker .writer .add_images ("test" , np_images , epoch , dataformats = "NHWC" )
1309- if tracker .name == "wandb" :
1310- tracker .log (
1311- {
1312- "test" : [
1313- wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " )
1314- for i , image in enumerate (images )
1315- ]
1316- }
1317- )
1303+ images = log_validation (pipeline , args , accelerator , epoch , is_final_validation = True )
13181304
13191305 if args .push_to_hub :
13201306 save_model_card (
0 commit comments