 from diffusers.utils.torch_utils import is_compiled_module


+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.27.0.dev0")

@@ -113,6 +116,71 @@ def save_model_card(
     model_card.save(os.path.join(repo_folder, "README.md"))


+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+    if args.validation_images is None:
+        images = []
+        for _ in range(args.num_validation_images):
+            with torch.cuda.amp.autocast():
+                image = pipeline(**pipeline_args, generator=generator).images[0]
+                images.append(image)
+    else:
+        images = []
+        for image in args.validation_images:
+            image = Image.open(image)
+            with torch.cuda.amp.autocast():
+                image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                images.append(image)
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+    return images
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
@@ -684,7 +752,6 @@ def main(args):
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb

     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
@@ -1265,10 +1332,6 @@ def compute_text_embeddings(prompt):
 
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
@@ -1279,26 +1342,6 @@ def compute_text_embeddings(prompt):
                     torch_dtype=weight_dtype,
                 )

-                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-                scheduler_args = {}
-
-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
-
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
-
-                    scheduler_args["variance_type"] = variance_type
-
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
-
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 if args.pre_compute_text_embeddings:
                     pipeline_args = {
                         "prompt_embeds": validation_prompt_encoder_hidden_states,
@@ -1307,36 +1350,13 @@ def compute_text_embeddings(prompt):
                 else:
                     pipeline_args = {"prompt": args.validation_prompt}

-                if args.validation_images is None:
-                    images = []
-                    for _ in range(args.num_validation_images):
-                        with torch.cuda.amp.autocast():
-                            image = pipeline(**pipeline_args, generator=generator).images[0]
-                            images.append(image)
-                else:
-                    images = []
-                    for image in args.validation_images:
-                        image = Image.open(image)
-                        with torch.cuda.amp.autocast():
-                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
-                            images.append(image)
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
-
-                del pipeline
-                torch.cuda.empty_cache()
+                images = log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    pipeline_args,
+                    epoch,
+                )

     # Save the lora layers
     accelerator.wait_for_everyone()
@@ -1364,46 +1384,21 @@ def compute_text_embeddings(prompt):
             args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
         )

-        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-        scheduler_args = {}
-
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
-
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-
-        pipeline = pipeline.to(accelerator.device)
-
         # load attention processors
         pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

         # run inference
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                for _ in range(args.num_validation_images)
-            ]
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                pipeline_args,
+                epoch,
+                is_final_validation=True,
+            )

         if args.push_to_hub:
             save_model_card(