118118 "decomp_attn" : None ,
119119 },
120120 },
121- "unetloop" : {
122- "module_name" : "sdxl_compiled_pipeline" ,
123- "load" : False ,
124- "keywords" : ["unetloop" ],
125- "wraps" : ["unet" , "scheduler" ],
126- "export_args" : {
127- "batch_size" : 1 ,
128- "height" : 1024 ,
129- "width" : 1024 ,
130- "max_length" : 64 ,
131- },
132- },
133121 "fullpipeline" : {
134122 "module_name" : "sdxl_compiled_pipeline" ,
135- "load" : False ,
123+ "load" : True ,
136124 "keywords" : ["fullpipeline" ],
137- "wraps" : ["text_encoder" , " unet" , "scheduler" , "vae" ],
125+ "wraps" : ["unet" , "scheduler" , "vae" ],
138126 "export_args" : {
139127 "batch_size" : 1 ,
140128 "height" : 1024 ,
@@ -190,6 +178,7 @@ def get_sd_model_map(hf_model_name):
190178 "stabilityai/sdxl-turbo" ,
191179 "stabilityai/stable-diffusion-xl-base-1.0" ,
192180 "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe" ,
181+ "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe" ,
193182 ]:
194183 return sdxl_model_map
195184 elif "stabilityai/stable-diffusion-3" in name :
@@ -233,6 +222,7 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        compiled_pipeline: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -243,11 +233,11 @@ def __init__(
243233 "exit_on_vmfb" : False ,
244234 "pipeline_dir" : pipeline_dir ,
245235 "input_mlir" : None ,
246- "attn_spec" : None ,
236+ "attn_spec" : attn_spec ,
247237 "external_weights" : None ,
248238 "external_weight_path" : None ,
249239 }
250- sd_model_map = get_sd_model_map (hf_model_name )
240+ sd_model_map = copy . deepcopy ( get_sd_model_map (hf_model_name ) )
251241 for submodel in sd_model_map :
252242 if "load" not in sd_model_map [submodel ]:
253243 sd_model_map [submodel ]["load" ] = True
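
The deepcopy above matters because get_sd_model_map returns a module-level dict whose nested lists and dicts would otherwise be mutated in place by every pipeline instance (for example the keyword appends in setup_punet below). A minimal standalone sketch of the aliasing problem, with illustrative keys:

import copy

shared_map = {"unet": {"keywords": ["fp16"]}}

# Without a deep copy, every "pipeline" mutates the same nested list:
a = shared_map
a["unet"]["keywords"].append("punet")
assert shared_map["unet"]["keywords"] == ["fp16", "punet"]  # change leaked upward

# With copy.deepcopy, each instance works on an independent structure:
b = copy.deepcopy(shared_map)
b["unet"]["keywords"].append("!punet")
assert shared_map["unet"]["keywords"] == ["fp16", "punet"]  # original untouched
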
@@ -311,6 +301,7 @@ def __init__(
         self.scheduler = None

         self.split_scheduler = True
+        self.compiled_pipeline = compiled_pipeline

         self.base_model_name = (
             hf_model_name
@@ -321,11 +312,6 @@ def __init__(
         self.is_sdxl = "xl" in self.base_model_name.lower()
         self.is_sd3 = "stable-diffusion-3" in self.base_model_name
         if self.is_sdxl:
-            if self.split_scheduler:
-                if self.map.get("unetloop"):
-                    self.map.pop("unetloop")
-                if self.map.get("fullpipeline"):
-                    self.map.pop("fullpipeline")
             self.tokenizers = [
                 CLIPTokenizer.from_pretrained(
                     self.base_model_name, subfolder="tokenizer"
@@ -339,6 +325,20 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if not self.compiled_pipeline:
+                if self.map.get("unetloop"):
+                    self.map.pop("unetloop")
+                if self.map.get("fullpipeline"):
+                    self.map.pop("fullpipeline")
+            elif self.compiled_pipeline:
+                self.map["unet"]["load"] = False
+                self.map["vae"]["load"] = False
+                self.load_scheduler(
+                    scheduler_id,
+                    num_inference_steps,
+                )
+                self.map["scheduler"]["runner"].unload()
+                self.map["scheduler"]["load"] = False
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
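
In the compiled branch above, unet and vae stay in the map for export metadata but are flagged load=False because the monolithic fullpipeline module wraps them, and the scheduler is exported via load_scheduler and then immediately unloaded so only its compiled artifact is retained. A standalone sketch of the pruning logic, using a stand-in map with illustrative keys:

def configure_map(m, compiled_pipeline):
    if not compiled_pipeline:
        # Split-scheduler path: drop the monolithic modules entirely.
        m.pop("unetloop", None)
        m.pop("fullpipeline", None)
    else:
        # Compiled path: unet/vae are wrapped by fullpipeline, so they are
        # exported but never loaded as standalone runners.
        m["unet"]["load"] = False
        m["vae"]["load"] = False
    return m

demo = {"unet": {"load": True}, "vae": {"load": True}, "fullpipeline": {"load": True}}
print(configure_map(demo, compiled_pipeline=True))
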
@@ -351,23 +351,27 @@ def __init__(

         self.latents_dtype = torch_dtypes[self.latents_precision]
         self.use_i8_punet = self.use_punet = use_i8_punet
+        if self.use_punet:
+            self.setup_punet()
+        else:
+            self.map["unet"]["keywords"].append("!punet")
+            self.map["unet"]["function_name"] = "run_forward"
+
+    def setup_punet(self):
         if self.use_i8_punet:
             self.map["unet"]["export_args"]["precision"] = "i8"
-            self.map["unet"]["export_args"]["use_punet"] = True
-            self.map["unet"]["use_weights_for_export"] = True
-            self.map["unet"]["keywords"].append("punet")
-            self.map["unet"]["module_name"] = "compiled_punet"
-            self.map["unet"]["function_name"] = "main"
             self.map["unet"]["export_args"]["external_weight_path"] = (
                 utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
             )
             for idx, word in enumerate(self.map["unet"]["keywords"]):
                 if word in ["fp32", "fp16"]:
                     self.map["unet"]["keywords"][idx] = "i8"
                     break
-        else:
-            self.map["unet"]["keywords"].append("!punet")
-            self.map["unet"]["function_name"] = "run_forward"
+        self.map["unet"]["export_args"]["use_punet"] = True
+        self.map["unet"]["use_weights_for_export"] = True
+        self.map["unet"]["keywords"].append("punet")
+        self.map["unet"]["module_name"] = "compiled_punet"
+        self.map["unet"]["function_name"] = "main"

     # LOAD

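The extracted setup_punet now applies the partitioned-UNet export settings unconditionally and keeps only the i8 specifics inside the precision check. The keyword swap it performs can be sketched in isolation (the function name here is a hypothetical stand-in, not part of the source):

def swap_precision_keyword(keywords, new="i8"):
    # Replace the first fp32/fp16 tag with the i8 tag, mirroring the loop above.
    for idx, word in enumerate(keywords):
        if word in ["fp32", "fp16"]:
            keywords[idx] = new
            break
    return keywords

assert swap_precision_keyword(["sdxl", "fp16", "unet"]) == ["sdxl", "i8", "unet"]
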
@@ -376,10 +380,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -425,7 +425,11 @@ def load_scheduler(
             except:
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
-        if self.cpu_scheduling:
+        elif self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
@@ -461,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
             text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
             uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

-        if self.compiled_pipeline:
-            return text_input_ids_list, uncond_input_ids_list
-        else:
-            prompt_embeds, add_text_embeds = self.text_encoder(
-                "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
-            )
-            return prompt_embeds, add_text_embeds
+        prompt_embeds, add_text_embeds = self.text_encoder(
+            "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
+        )
+        return prompt_embeds, add_text_embeds

     def prepare_latents(
         self,
@@ -565,9 +566,11 @@ def _produce_latents_sdxl(
             [guidance_scale],
             dtype=self.map["unet"]["np_dtype"],
         )
+        # Disable progress bar if we aren't in verbose mode or if we're printing
+        # benchmark latencies for unet.
         for i, t in tqdm(
             enumerate(timesteps),
-            disable=(self.map["unet"].get("benchmark") and self.verbose),
+            disable=(self.map["unet"].get("benchmark") or not self.verbose),
         ):
             if self.cpu_scheduling:
                 latent_model_input, t = self.scheduler.scale_model_input(
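
The old condition hid the progress bar only when benchmarking and verbose mode were both enabled; the corrected expression hides it whenever unet benchmarking is on or verbose mode is off. A quick check of the new expression in isolation (values illustrative):

for benchmark, verbose in [(None, True), (True, True), (None, False), (True, False)]:
    disable = bool(benchmark) or not verbose
    print(f"benchmark={benchmark}, verbose={verbose} -> bar hidden: {disable}")
# Only the non-benchmark, verbose run leaves the tqdm bar visible.
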
@@ -608,6 +611,75 @@ def _produce_latents_sdxl(
             latents = self.scheduler("run_step", [noise_pred, t, latents])
         return latents

+    def produce_images_compiled(
+        self,
+        sample,
+        prompt_embeds,
+        text_embeds,
+        guidance_scale,
+    ):
+        pipe_inputs = [
+            sample,
+            prompt_embeds,
+            text_embeds,
+            torch.as_tensor([guidance_scale], dtype=sample.dtype),
+        ]
+        # image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
+        return image
+
+    def prepare_sampling_inputs(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        steps: int = 30,
+        batch_count: int = 1,
+        guidance_scale: float = 7.5,
+        seed: float = -1,
+        cpu_scheduling: bool = True,
+        scheduler_id: str = "EulerDiscrete",
+        return_imgs: bool = False,
+    ):
+        needs_new_scheduler = (
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        if not self.scheduler and not self.compiled_pipeline:
+            needs_new_scheduler = True
+
+        if guidance_scale == 0:
+            negative_prompt = prompt
+            prompt = ""
+
+        self.cpu_scheduling = cpu_scheduling
+        if steps and needs_new_scheduler:
+            self.num_inference_steps = steps
+            self.load_scheduler(scheduler_id, steps)
+
+        pipe_start = time.time()
+        numpy_images = []
+
+        samples = self.get_rand_latents(seed, batch_count)
+
+        # Tokenize prompt and negative prompt.
+        if self.is_sdxl:
+            prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_embeds = encode_prompt(
+                self, prompt, negative_prompt
+            )
+        produce_latents_input = [
+            samples[0],
+            prompt_embeds,
+            negative_embeds,
+            steps,
+            guidance_scale,
+        ]
+        return produce_latents_input
+
     def generate_images(
         self,
         prompt: str,
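
Taken together, prepare_sampling_inputs stages exactly the tensors the monolithic module consumes, and produce_images_compiled hands them to the fullpipeline runner's produce_image_latents entry point. A hedged usage sketch, assuming the sd_pipe instance from the __main__ block below was constructed with compiled_pipeline=True and already prepared and loaded (prompt values illustrative):

inputs = sd_pipe.prepare_sampling_inputs(
    prompt="a photo of an astronaut riding a horse",
    negative_prompt="blurry",
    steps=30,
    guidance_scale=7.5,
    seed=42,
)
sample, prompt_embeds, negative_embeds, _steps, guidance_scale = inputs
image = sd_pipe.produce_images_compiled(
    sample, prompt_embeds, negative_embeds, guidance_scale
).to_host()
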
@@ -653,18 +725,23 @@ def generate_images(
         )

         for i in range(batch_count):
-            produce_latents_input = [
-                samples[i],
-                prompt_embeds,
-                negative_embeds,
-                steps,
-                guidance_scale,
-            ]
-            if self.is_sdxl:
-                latents = self._produce_latents_sdxl(*produce_latents_input)
+            if self.compiled_pipeline:
+                image = self.produce_images_compiled(
+                    samples[i], prompt_embeds, negative_embeds, guidance_scale
+                ).to_host()
             else:
-                latents = self._produce_latents_sd(*produce_latents_input)
-                image = self.vae("decode", [latents])
+                produce_latents_input = [
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    steps,
+                    guidance_scale,
+                ]
+                if self.is_sdxl:
+                    latents = self._produce_latents_sdxl(*produce_latents_input)
+                else:
+                    latents = self._produce_latents_sd(*produce_latents_input)
+                image = self.vae("decode", [latents])
             numpy_images.append(image)
         pipe_end = time.time()

@@ -750,8 +827,10 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        False,
+        args.compiled_pipeline,
     )
-    sd_pipe.prepare_all()
+    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
     sd_pipe.load_map()
     sd_pipe.generate_images(
         args.prompt,
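
For reference, the __main__ wiring now passes batch_prompts=False and args.compiled_pipeline positionally, and forwards the requested step count into prepare_all, presumably so the exported modules see it at export time. A hedged invocation sketch; the flag names are inferred from the args.* uses in this diff and not verified against the argparser:

# python sd_pipeline.py --compiled_pipeline --num_inference_steps 30 ...
# prepare_all(num_steps=...) exports the scheduler/fullpipeline artifacts before
# load_map() loads only the compiled entry points.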