88import torchvision .transforms as transforms
99
1010import diffusers
11- from diffusers import AutoencoderKL , DDIMScheduler
11+ from diffusers import AutoencoderKL , DDIMScheduler , EulerDiscreteScheduler , EulerAncestralDiscreteScheduler , DPMSolverMultistepScheduler , PNDMScheduler
1212
1313from tqdm .auto import tqdm
1414from transformers import CLIPTextModel , CLIPTokenizer
@@ -43,8 +43,8 @@ def main(args):
4343
4444 # create validation pipeline
4545 tokenizer = CLIPTokenizer .from_pretrained (args .pretrained_model_path , subfolder = "tokenizer" )
46- text_encoder = CLIPTextModel .from_pretrained (args .pretrained_model_path , subfolder = "text_encoder" ).cuda ( )
47- vae = AutoencoderKL .from_pretrained (args .pretrained_model_path , subfolder = "vae" ).cuda ( )
46+ text_encoder = CLIPTextModel .from_pretrained (args .pretrained_model_path , subfolder = "text_encoder" ).to ( args . device )
47+ vae = AutoencoderKL .from_pretrained (args .pretrained_model_path , subfolder = "vae" ).to ( args . device )
4848
4949 sample_idx = 0
5050 for model_idx , model_config in enumerate (config ):
@@ -53,13 +53,15 @@ def main(args):
5353 model_config .L = model_config .get ("L" , args .L )
5454
5555 inference_config = OmegaConf .load (model_config .get ("inference_config" , args .inference_config ))
56- unet = UNet3DConditionModel .from_pretrained_2d (args .pretrained_model_path , subfolder = "unet" , unet_additional_kwargs = OmegaConf .to_container (inference_config .unet_additional_kwargs )).cuda ( )
56+ unet = UNet3DConditionModel .from_pretrained_2d (args .pretrained_model_path , subfolder = "unet" , unet_additional_kwargs = OmegaConf .to_container (inference_config .unet_additional_kwargs )).to ( args . device )
5757
5858 # load controlnet model
5959 controlnet = controlnet_images = None
6060 if model_config .get ("controlnet_path" , "" ) != "" :
61- assert model_config .get ("controlnet_images" , "" ) != ""
62- assert model_config .get ("controlnet_config" , "" ) != ""
61+ if not model_config .get ("controlnet_images" , "" ):
62+ raise ValueError ("controlnet_images must be specified when controlnet_path is set" )
63+ if not model_config .get ("controlnet_config" , "" ):
64+ raise ValueError ("controlnet_config must be specified when controlnet_path is set" )
6365
6466 unet .config .num_attention_heads = 8
6567 unet .config .projection_class_embeddings_input_dim = None
@@ -74,14 +76,15 @@ def main(args):
7476 controlnet_state_dict = {name : param for name , param in controlnet_state_dict .items () if "pos_encoder.pe" not in name }
7577 controlnet_state_dict .pop ("animatediff_config" , "" )
7678 controlnet .load_state_dict (controlnet_state_dict )
77- controlnet .cuda ( )
79+ controlnet .to ( args . device )
7880
7981 image_paths = model_config .controlnet_images
8082 if isinstance (image_paths , str ): image_paths = [image_paths ]
8183
8284 print (f"controlnet image paths:" )
8385 for path in image_paths : print (path )
84- assert len (image_paths ) <= model_config .L
86+ if len (image_paths ) > model_config .L :
87+ raise ValueError (f"Number of controlnet images ({ len (image_paths )} ) exceeds video length ({ model_config .L } )" )
8588
8689 image_transforms = transforms .Compose ([
8790 transforms .RandomResizedCrop (
@@ -105,7 +108,7 @@ def image_norm(image):
105108 for i , image in enumerate (controlnet_images ):
106109 Image .fromarray ((255. * (image .numpy ().transpose (1 ,2 ,0 ))).astype (np .uint8 )).save (f"{ savedir } /control_images/{ i } .png" )
107110
108- controlnet_images = torch .stack (controlnet_images ).unsqueeze (0 ).cuda ( )
111+ controlnet_images = torch .stack (controlnet_images ).unsqueeze (0 ).to ( args . device )
109112 controlnet_images = rearrange (controlnet_images , "b f c h w -> b c f h w" )
110113
111114 if controlnet .use_simplified_condition_embedding :
@@ -119,11 +122,22 @@ def image_norm(image):
119122 unet .enable_xformers_memory_efficient_attention ()
120123 if controlnet is not None : controlnet .enable_xformers_memory_efficient_attention ()
121124
125+ scheduler_kwargs = OmegaConf .to_container (inference_config .noise_scheduler_kwargs )
126+ scheduler_map = {
127+ "ddim" : DDIMScheduler ,
128+ "euler" : EulerDiscreteScheduler ,
129+ "euler-a" : EulerAncestralDiscreteScheduler ,
130+ "dpm++" : DPMSolverMultistepScheduler ,
131+ "dpm++-karras" : lambda ** kw : DPMSolverMultistepScheduler (** kw , use_karras_sigmas = True ),
132+ "pndm" : PNDMScheduler ,
133+ }
134+ scheduler = scheduler_map [args .scheduler ](** scheduler_kwargs )
135+
122136 pipeline = AnimationPipeline (
123137 vae = vae , text_encoder = text_encoder , tokenizer = tokenizer , unet = unet ,
124138 controlnet = controlnet ,
125- scheduler = DDIMScheduler ( ** OmegaConf . to_container ( inference_config . noise_scheduler_kwargs )) ,
126- ).to ("cuda" )
139+ scheduler = scheduler ,
140+ ).to (args . device )
127141
128142 pipeline = load_weights (
129143 pipeline ,
@@ -137,7 +151,15 @@ def image_norm(image):
137151 dreambooth_model_path = model_config .get ("dreambooth_path" , "" ),
138152 lora_model_path = model_config .get ("lora_model_path" , "" ),
139153 lora_alpha = model_config .get ("lora_alpha" , 0.8 ),
140- ).to ("cuda" )
154+ ).to (args .device )
155+
156+ # memory optimizations
157+ pipeline .enable_vae_slicing ()
158+ if args .half_precision and args .device != "cpu" :
159+ pipeline .unet .half ()
160+ pipeline .text_encoder .half ()
161+ if controlnet is not None :
162+ controlnet .half ()
141163
142164 prompts = model_config .prompt
143165 n_prompts = list (model_config .n_prompt ) * len (prompts ) if len (model_config .n_prompt ) == 1 else model_config .n_prompt
@@ -194,6 +216,17 @@ def image_norm(image):
194216
195217 parser .add_argument ("--without-xformers" , action = "store_true" )
196218 parser .add_argument ("--format" , type = str , default = "gif" , choices = ["gif" , "mp4" ])
219+ parser .add_argument ("--scheduler" , type = str , default = "ddim" , choices = ["ddim" , "euler" , "euler-a" , "dpm++" , "dpm++-karras" , "pndm" ])
220+ parser .add_argument ("--half-precision" , action = "store_true" , help = "Use float16 for lower VRAM usage" )
221+ parser .add_argument ("--device" , type = str , default = None , help = "Device to use (cuda, mps, cpu). Auto-detected if not specified." )
197222
198223 args = parser .parse_args ()
224+ if args .device is None :
225+ if torch .cuda .is_available ():
226+ args .device = "cuda"
227+ elif hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available ():
228+ args .device = "mps"
229+ else :
230+ args .device = "cpu"
231+ print (f"Using device: { args .device } " )
199232 main (args )
0 commit comments