from diffusers.utils.torch_utils import is_compiled_module


+if is_wandb_available():
+    import wandb
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")

@@ -140,6 +143,61 @@ def save_model_card(
    model_card.save(os.path.join(repo_folder, "README.md"))


+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+    with torch.cuda.amp.autocast():
+        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+    return images
+
+
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
@@ -862,7 +920,6 @@ def main(args):
    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
@@ -1615,10 +1672,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                # create pipeline
                if not args.train_text_encoder:
                    text_encoder_one = text_encoder_cls_one.from_pretrained(
@@ -1644,50 +1697,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                    torch_dtype=weight_dtype,
                )

-                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-                scheduler_args = {}
-
-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
-
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
-
-                    scheduler_args["variance_type"] = variance_type
-
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
-
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                pipeline_args = {"prompt": args.validation_prompt}

-                with torch.cuda.amp.autocast():
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
-
-                del pipeline
-                torch.cuda.empty_cache()
+                images = log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    pipeline_args,
+                    epoch,
+                )

    # Save the lora layers
    accelerator.wait_for_everyone()
@@ -1733,45 +1751,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
            torch_dtype=weight_dtype,
        )

-        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-        scheduler_args = {}
-
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
-
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-
        # load attention processors
        pipeline.load_lora_weights(args.output_dir)

        # run inference
        images = []
        if args.validation_prompt and args.num_validation_images > 0:
-            pipeline = pipeline.to(accelerator.device)
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                for _ in range(args.num_validation_images)
-            ]
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                pipeline_args,
+                epoch,
+                is_final_validation=True,
+            )

        if args.push_to_hub:
            save_model_card(
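
For reference, a minimal sketch of how the refactored helper is driven by the two call sites in this diff. `pipeline`, `accelerator`, and `epoch` are assumed to come from the surrounding training script, and the `args` values below are hypothetical stand-ins for the script's CLI flags; the only behavioral switch is `is_final_validation`, which selects the "validation" or "test" logging key inside `log_validation`.

```python
# Sketch only: `pipeline`, `accelerator`, and `epoch` are assumed to be set up
# by the surrounding training script; the args values here are hypothetical.
from types import SimpleNamespace

args = SimpleNamespace(
    validation_prompt="a photo of sks dog in a bucket",  # hypothetical prompt
    num_validation_images=4,
    seed=42,
)

# Intermediate validation during training: images are logged under "validation".
images = log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args={"prompt": args.validation_prompt},
    epoch=epoch,
)

# Final pass after the LoRA weights are saved and reloaded: logged under "test".
images = log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args={"prompt": args.validation_prompt, "num_inference_steps": 25},
    epoch=epoch,
    is_final_validation=True,
)
```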