118118 "decomp_attn" : None ,
119119 },
120120 },
121- "unetloop" : {
122- "module_name" : "sdxl_compiled_pipeline" ,
123- "load" : False ,
124- "keywords" : ["unetloop" ],
125- "wraps" : ["unet" , "scheduler" ],
126- "export_args" : {
127- "batch_size" : 1 ,
128- "height" : 1024 ,
129- "width" : 1024 ,
130- "max_length" : 64 ,
131- },
132- },
133121 "fullpipeline" : {
134122 "module_name" : "sdxl_compiled_pipeline" ,
135- "load" : False ,
123+ "load" : True ,
136124 "keywords" : ["fullpipeline" ],
137- "wraps" : ["text_encoder" , " unet" , "scheduler" , "vae" ],
125+ "wraps" : ["unet" , "scheduler" , "vae" ],
138126 "export_args" : {
139127 "batch_size" : 1 ,
140128 "height" : 1024 ,
@@ -234,6 +222,7 @@ def __init__(
234222 benchmark : bool | dict [bool ] = False ,
235223 verbose : bool = False ,
236224 batch_prompts : bool = False ,
225+ compiled_pipeline : bool = False ,
237226 ):
238227 common_export_args = {
239228 "hf_model_name" : None ,
@@ -312,6 +301,7 @@ def __init__(
312301 self .scheduler = None
313302
314303 self .split_scheduler = True
304+ self .compiled_pipeline = compiled_pipeline
315305
316306 self .base_model_name = (
317307 hf_model_name
@@ -322,11 +312,6 @@ def __init__(
322312 self .is_sdxl = "xl" in self .base_model_name .lower ()
323313 self .is_sd3 = "stable-diffusion-3" in self .base_model_name
324314 if self .is_sdxl :
325- if self .split_scheduler :
326- if self .map .get ("unetloop" ):
327- self .map .pop ("unetloop" )
328- if self .map .get ("fullpipeline" ):
329- self .map .pop ("fullpipeline" )
330315 self .tokenizers = [
331316 CLIPTokenizer .from_pretrained (
332317 self .base_model_name , subfolder = "tokenizer"
@@ -340,6 +325,20 @@ def __init__(
340325 self .scheduler_device = self .map ["unet" ]["device" ]
341326 self .scheduler_driver = self .map ["unet" ]["driver" ]
342327 self .scheduler_target = self .map ["unet" ]["target" ]
328+ if not self .compiled_pipeline :
329+ if self .map .get ("unetloop" ):
330+ self .map .pop ("unetloop" )
331+ if self .map .get ("fullpipeline" ):
332+ self .map .pop ("fullpipeline" )
333+ elif self .compiled_pipeline :
334+ self .map ["unet" ]["load" ] = False
335+ self .map ["vae" ]["load" ] = False
336+ self .load_scheduler (
337+ scheduler_id ,
338+ num_inference_steps ,
339+ )
340+ self .map ["scheduler" ]["runner" ].unload ()
341+ self .map ["scheduler" ]["load" ] = False
343342 elif not self .is_sd3 :
344343 self .tokenizer = CLIPTokenizer .from_pretrained (
345344 self .base_model_name , subfolder = "tokenizer"
@@ -381,10 +380,6 @@ def load_scheduler(
381380 scheduler_id : str ,
382381 steps : int = 30 ,
383382 ):
384- if self .is_sd3 :
385- scheduler_device = self .mmdit .device
386- else :
387- scheduler_device = self .unet .device
388383 if not self .cpu_scheduling :
389384 self .map ["scheduler" ] = {
390385 "module_name" : "compiled_scheduler" ,
@@ -430,7 +425,11 @@ def load_scheduler(
430425 except :
431426 print ("JIT export of scheduler failed. Loading CPU scheduler." )
432427 self .cpu_scheduling = True
433- if self .cpu_scheduling :
428+ elif self .cpu_scheduling :
429+ if self .is_sd3 :
430+ scheduler_device = self .mmdit .device
431+ else :
432+ scheduler_device = self .unet .device
434433 scheduler = schedulers .get_scheduler (self .base_model_name , scheduler_id )
435434 self .scheduler = schedulers .SharkSchedulerCPUWrapper (
436435 scheduler ,
@@ -615,6 +614,72 @@ def _produce_latents_sdxl(
615614 latents = self .scheduler ("run_step" , [noise_pred , t , latents ])
616615 return latents
617616
617+ def produce_images_compiled (
618+ self ,
618+ sample ,
619+ prompt_embeds ,
620+ text_embeds ,
621+ guidance_scale ,
622+ ):
623+ pipe_inputs = [
624+ sample ,
625+ prompt_embeds ,
626+ text_embeds ,
627+ guidance_scale ,
628+ ]
629+ # NOTE(review): self.compiled_pipeline is assigned a bool flag in __init__ —
629+ # confirm it is rebound to the loaded fullpipeline runner before this call.
629+ image = self .compiled_pipeline ("produce_img_latents" , pipe_inputs )
630+ return image
630+
631+ def prepare_sampling_inputs (
632+ self ,
633+ prompt : str ,
634+ negative_prompt : str = "" ,
635+ steps : int = 30 ,
636+ batch_count : int = 1 ,
637+ guidance_scale : float = 7.5 ,
638+ seed : float = - 1 ,
639+ cpu_scheduling : bool = True ,
640+ scheduler_id : str = "EulerDiscrete" ,
641+ return_imgs : bool = False ,
642+ ):
643+ # NOTE(review): `and` binds tighter than `or`, so self.split_scheduler only
643+ # guards the cpu_scheduling clause, not the steps clause — if it should gate
643+ # both conditions, parenthesize the `or` expression. Confirm intent.
643+ needs_new_scheduler = (
644+ (steps and steps != self .num_inference_steps )
645+ or (cpu_scheduling != self .cpu_scheduling )
646+ and self .split_scheduler
647+ )
648+ if not self .scheduler and not self .compiled_pipeline :
649+ needs_new_scheduler = True
650+
651+ if guidance_scale == 0 :
652+ negative_prompt = prompt
653+ prompt = ""
654+
655+ self .cpu_scheduling = cpu_scheduling
656+ if steps and needs_new_scheduler :
657+ self .num_inference_steps = steps
658+ self .load_scheduler (scheduler_id , steps )
659+
660+ pipe_start = time .time ()
661+ numpy_images = []
662+
663+ samples = self .get_rand_latents (seed , batch_count )
664+
665+ # Tokenize prompt and negative prompt.
666+ if self .is_sdxl :
667+ prompt_embeds , negative_embeds = self .encode_prompts_sdxl (
668+ prompt , negative_prompt
669+ )
670+ else :
671+ prompt_embeds , negative_embeds = encode_prompt (
672+ self , prompt , negative_prompt
673+ )
674+ produce_latents_input = [
675+ samples [0 ],
676+ prompt_embeds ,
677+ negative_embeds ,
678+ steps ,
679+ guidance_scale ,
680+ ]
681+ return produce_latents_input
682+
618683 def generate_images (
619684 self ,
620685 prompt : str ,
@@ -660,18 +725,26 @@ def generate_images(
660725 )
661726
662727 for i in range (batch_count ):
663- produce_latents_input = [
664- samples [i ],
665- prompt_embeds ,
666- negative_embeds ,
667- steps ,
668- guidance_scale ,
669- ]
670- if self .is_sdxl :
671- latents = self ._produce_latents_sdxl (* produce_latents_input )
728+ if self .compiled_pipeline :
729+ image = self .produce_images_compiled (
730+ samples [i ],
731+ prompt_embeds ,
732+ negative_embeds ,
733+ guidance_scale
734+ )
672735 else :
673- latents = self ._produce_latents_sd (* produce_latents_input )
674- image = self .vae ("decode" , [latents ])
736+ produce_latents_input = [
737+ samples [i ],
738+ prompt_embeds ,
739+ negative_embeds ,
740+ steps ,
741+ guidance_scale ,
742+ ]
743+ if self .is_sdxl :
744+ latents = self ._produce_latents_sdxl (* produce_latents_input )
745+ else :
746+ latents = self ._produce_latents_sd (* produce_latents_input )
747+ image = self .vae ("decode" , [latents ])
675748 numpy_images .append (image )
676749 pipe_end = time .time ()
677750
@@ -757,6 +830,8 @@ def numpy_to_pil_image(images):
757830 args .use_i8_punet ,
758831 benchmark ,
759832 args .verbose ,
833+ False ,
834+ args .compiled_pipeline ,
760835 )
761836 sd_pipe .prepare_all ()
762837 sd_pipe .load_map ()
0 commit comments