             "decomp_attn": None,
         },
     },
-    "unetloop": {
-        "module_name": "sdxl_compiled_pipeline",
-        "load": False,
-        "keywords": ["unetloop"],
-        "wraps": ["unet", "scheduler"],
-        "export_args": {
-            "batch_size": 1,
-            "height": 1024,
-            "width": 1024,
-            "max_length": 64,
-        },
-    },
     "fullpipeline": {
         "module_name": "sdxl_compiled_pipeline",
-        "load": False,
+        "load": True,
         "keywords": ["fullpipeline"],
-        "wraps": ["text_encoder", "unet", "scheduler", "vae"],
+        "wraps": ["unet", "scheduler", "vae"],
         "export_args": {
             "batch_size": 1,
             "height": 1024,
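With `unetloop` (which wrapped only `unet` and `scheduler`) removed, `fullpipeline` becomes the single fused module and is now loaded by default, wrapping `unet`, `scheduler`, and `vae`. Conceptually, such a module runs the whole denoise-and-decode loop in one invocation. A rough sketch of what that computes; only the entry-point name `produce_image_latents` is confirmed by the runner call later in this diff, and every helper signature here is an assumption:

```python
# Illustrative only: what a module wrapping unet, scheduler, and vae would
# compute in one call. Helper names/signatures are assumptions; only the
# entry point "produce_image_latents" appears in this diff.
def produce_image_latents(unet, scheduler, vae, sample, prompt_embeds,
                          text_embeds, guidance_scale):
    latents = scheduler.initialize(sample)  # scale the initial noise
    for t in scheduler.timesteps:           # step count baked in at export
        noise_pred = unet(latents, t, prompt_embeds, text_embeds, guidance_scale)
        latents = scheduler.step(noise_pred, t, latents)
    return vae.decode(latents)              # decode latents to an image
```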
@@ -234,6 +222,7 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        compiled_pipeline: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -312,6 +301,7 @@ def __init__(
         self.scheduler = None

         self.split_scheduler = True
+        self.compiled_pipeline = compiled_pipeline

         self.base_model_name = (
             hf_model_name
@@ -322,11 +312,6 @@ def __init__(
         self.is_sdxl = "xl" in self.base_model_name.lower()
         self.is_sd3 = "stable-diffusion-3" in self.base_model_name
         if self.is_sdxl:
-            if self.split_scheduler:
-                if self.map.get("unetloop"):
-                    self.map.pop("unetloop")
-                if self.map.get("fullpipeline"):
-                    self.map.pop("fullpipeline")
             self.tokenizers = [
                 CLIPTokenizer.from_pretrained(
                     self.base_model_name, subfolder="tokenizer"
@@ -340,6 +325,20 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if not self.compiled_pipeline:
+                if self.map.get("unetloop"):
+                    self.map.pop("unetloop")
+                if self.map.get("fullpipeline"):
+                    self.map.pop("fullpipeline")
+            elif self.compiled_pipeline:
+                self.map["unet"]["load"] = False
+                self.map["vae"]["load"] = False
+                self.load_scheduler(
+                    scheduler_id,
+                    num_inference_steps,
+                )
+                self.map["scheduler"]["runner"].unload()
+                self.map["scheduler"]["load"] = False
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
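The net effect of the new flag at construction time, restated as a standalone sketch (the helper function is illustrative; the keys and `load` flags come from this diff, and the real code additionally pre-bakes the scheduler via `load_scheduler` and then unloads its runner):

```python
# Hedged restatement of the branch above. In compiled mode the fused
# "fullpipeline" module subsumes unet and vae, so their individual loads
# are disabled; in decomposed mode the fused wrappers are dropped instead.
def apply_pipeline_mode(model_map: dict, compiled_pipeline: bool) -> dict:
    if not compiled_pipeline:
        model_map.pop("unetloop", None)
        model_map.pop("fullpipeline", None)
    else:
        for key in ("unet", "vae", "scheduler"):
            if key in model_map:
                model_map[key]["load"] = False
    return model_map
```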
@@ -381,10 +380,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -430,7 +425,11 @@ def load_scheduler(
             except:
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
-        if self.cpu_scheduling:
+        elif self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
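Note the control-flow change: with `elif`, the `except` branch above (which flips `self.cpu_scheduling` to `True` after the compiled path was already entered) no longer falls through to build the CPU wrapper in the same call. A hedged caller-side sketch of the implied retry:

```python
# Illustrative retry, assuming the behavior shown in this hunk: a failed JIT
# export only sets cpu_scheduling; a second call then takes the elif branch
# and builds the SharkSchedulerCPUWrapper. `pipe` is a pipeline instance.
def load_scheduler_with_fallback(pipe, scheduler_id="EulerDiscrete", steps=30):
    pipe.load_scheduler(scheduler_id, steps)
    if pipe.cpu_scheduling and pipe.scheduler is None:
        pipe.load_scheduler(scheduler_id, steps)
```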
@@ -466,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
             text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
             uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

-        if self.compiled_pipeline:
-            return text_input_ids_list, uncond_input_ids_list
-        else:
-            prompt_embeds, add_text_embeds = self.text_encoder(
-                "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
-            )
-            return prompt_embeds, add_text_embeds
+        prompt_embeds, add_text_embeds = self.text_encoder(
+            "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
+        )
+        return prompt_embeds, add_text_embeds

     def prepare_latents(
         self,
@@ -615,6 +611,75 @@ def _produce_latents_sdxl(
             latents = self.scheduler("run_step", [noise_pred, t, latents])
         return latents

+    def produce_images_compiled(
+        self,
+        sample,
+        prompt_embeds,
+        text_embeds,
+        guidance_scale,
+    ):
+        pipe_inputs = [
+            sample,
+            prompt_embeds,
+            text_embeds,
+            torch.as_tensor([guidance_scale], dtype=sample.dtype),
+        ]
+        # image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
+        return image
+
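For reference, a hedged sketch of plausible inputs for this entry point, assuming the `export_args` above (batch_size=1, 1024x1024, max_length=64), fp16 weights, and the usual SDXL packing of cond/uncond pairs; none of these shapes are stated in the diff itself:

```python
import torch

# Assumed SDXL shapes for batch_size=1 at 1024x1024 with max_length=64.
sample = torch.randn(1, 4, 128, 128, dtype=torch.float16)      # 1024 / 8 = 128
prompt_embeds = torch.randn(2, 64, 2048, dtype=torch.float16)  # cond + uncond
text_embeds = torch.randn(2, 1280, dtype=torch.float16)        # pooled embeds
# image = pipe.produce_images_compiled(sample, prompt_embeds, text_embeds, 7.5)
```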
+    def prepare_sampling_inputs(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        steps: int = 30,
+        batch_count: int = 1,
+        guidance_scale: float = 7.5,
+        seed: float = -1,
+        cpu_scheduling: bool = True,
+        scheduler_id: str = "EulerDiscrete",
+        return_imgs: bool = False,
+    ):
+        needs_new_scheduler = (
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        if not self.scheduler and not self.compiled_pipeline:
+            needs_new_scheduler = True
+
+        if guidance_scale == 0:
+            negative_prompt = prompt
+            prompt = ""
+
+        self.cpu_scheduling = cpu_scheduling
+        if steps and needs_new_scheduler:
+            self.num_inference_steps = steps
+            self.load_scheduler(scheduler_id, steps)
+
+        pipe_start = time.time()
+        numpy_images = []
+
+        samples = self.get_rand_latents(seed, batch_count)
+
+        # Tokenize prompt and negative prompt.
+        if self.is_sdxl:
+            prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_embeds = encode_prompt(
+                self, prompt, negative_prompt
+            )
+        produce_latents_input = [
+            samples[0],
+            prompt_embeds,
+            negative_embeds,
+            steps,
+            guidance_scale,
+        ]
+        return produce_latents_input
+
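One thing worth flagging in `needs_new_scheduler`: Python's `and` binds tighter than `or`, so the expression parses as `A or (B and C)`; the `split_scheduler` guard applies only to the `cpu_scheduling` change, not to the step-count change. If the grouping `(A or B) and C` was intended, explicit parentheses are needed. A quick check:

```python
# and binds tighter than or: A or B and C  ==  A or (B and C)
A, B, C = True, False, False
assert (A or B and C) == (A or (B and C))   # True
assert (A or B and C) != ((A or B) and C)   # the two groupings differ here
```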
     def generate_images(
         self,
         prompt: str,
@@ -660,18 +725,21 @@ def generate_images(
             )

         for i in range(batch_count):
-            produce_latents_input = [
-                samples[i],
-                prompt_embeds,
-                negative_embeds,
-                steps,
-                guidance_scale,
-            ]
-            if self.is_sdxl:
-                latents = self._produce_latents_sdxl(*produce_latents_input)
+            if self.compiled_pipeline:
+                image = self.produce_images_compiled(samples[i], prompt_embeds, negative_embeds, guidance_scale).to_host()
             else:
-                latents = self._produce_latents_sd(*produce_latents_input)
-                image = self.vae("decode", [latents])
+                produce_latents_input = [
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    steps,
+                    guidance_scale,
+                ]
+                if self.is_sdxl:
+                    latents = self._produce_latents_sdxl(*produce_latents_input)
+                else:
+                    latents = self._produce_latents_sd(*produce_latents_input)
+                image = self.vae("decode", [latents])
             numpy_images.append(image)
         pipe_end = time.time()

@@ -757,8 +825,10 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        False,
+        args.compiled_pipeline,
     )
-    sd_pipe.prepare_all()
+    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
     sd_pipe.load_map()
     sd_pipe.generate_images(
         args.prompt,
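Putting the driver changes together, a hedged end-to-end sketch; the class name, model id, and most argument values are placeholders, while the call order (construct, `prepare_all`, `load_map`, `generate_images`) and the `compiled_pipeline`/`num_steps` plumbing come from this diff:

```python
# Assumes the pipeline class from this file is importable as SharkSDPipeline
# (name assumed, not confirmed by the diff).
sd_pipe = SharkSDPipeline(
    hf_model_name="stabilityai/stable-diffusion-xl-base-1.0",  # assumed value
    compiled_pipeline=True,   # exercise the fused fullpipeline module
    # ... remaining constructor arguments omitted for brevity
)
sd_pipe.prepare_all(num_steps=30)  # step count baked into the scheduler export
sd_pipe.load_map()                 # loads "fullpipeline" instead of unet/vae
sd_pipe.generate_images("a photo of a corgi", steps=30, guidance_scale=7.5)
```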