1717from turbine_models .custom_models .sd_inference import utils
1818from turbine_models .model_runner import vmfbRunner
1919from transformers import CLIPTokenizer
20+ from diffusers import FlowMatchEulerDiscreteScheduler
2021
2122from PIL import Image
2223import os
@@ -44,10 +45,8 @@ class SharkSD3Pipeline:
4445 def __init__ (
4546 self ,
4647 hf_model_name : str ,
47- # scheduler_id: str,
4848 height : int ,
4949 width : int ,
50- shift : float ,
5150 precision : str ,
5251 max_length : int ,
5352 batch_size : int ,
@@ -60,9 +59,12 @@ def __init__(
6059 pipeline_dir : str = "./shark_vmfbs" ,
6160 external_weights_dir : str = "./shark_weights" ,
6261 external_weights : str = "safetensors" ,
63- vae_decomp_attn : bool = True ,
64- custom_vae : str = "" ,
62+ vae_decomp_attn : bool = False ,
6563 cpu_scheduling : bool = False ,
64+ vae_precision : str = "fp32" ,
65+ scheduler_id : str = None , #compatibility only, always uses EulerFlowScheduler
66+ shift : float = 1.0 ,
67+
6668 ):
6769 self .hf_model_name = hf_model_name
6870 # self.scheduler_id = scheduler_id
@@ -120,10 +122,11 @@ def __init__(
120122 self .external_weights_dir = external_weights_dir
121123 self .external_weights = external_weights
122124 self .vae_decomp_attn = vae_decomp_attn
123- self .custom_vae = custom_vae
125+ self .custom_vae = None
124126 self .cpu_scheduling = cpu_scheduling
125127 self .torch_dtype = torch .float32 if self .precision == "fp32" else torch .float16
126- self .vae_dtype = torch .float32
128+ self .vae_precision = vae_precision if vae_precision else self .precision
129+ self .vae_dtype = torch .float32 if vae_precision == "fp32" else torch .float16
127130 # TODO: set this based on user-inputted guidance scale and negative prompt.
128131 self .do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True
129132
@@ -206,7 +209,12 @@ def is_prepared(self, vmfbs, weights):
206209 )
207210 if w_key == "clip" :
208211 default_name = os .path .join (
209- self .external_weights_dir , f"sd3_clip_fp16.irpa"
212+ self .external_weights_dir , f"sd3_text_encoders_{ self .precision } .irpa"
213+ )
214+ if w_key == "mmdit" :
215+ default_name = os .path .join (
216+ self .external_weights_dir ,
217+ f"sd3_mmdit_{ self .precision } ." + self .external_weights ,
210218 )
211219 if weights [w_key ] is None and os .path .exists (default_name ):
212220 weights [w_key ] = os .path .join (default_name )
@@ -357,7 +365,7 @@ def export_submodel(
357365 self .batch_size ,
358366 self .height ,
359367 self .width ,
360- "fp32" ,
368+ self . vae_precision ,
361369 "vmfb" ,
362370 self .external_weights ,
363371 vae_external_weight_path ,
@@ -419,10 +427,16 @@ def load_pipeline(
419427 unet_loaded = time .time ()
420428 print ("\n [LOG] MMDiT loaded in " , unet_loaded - load_start , "sec" )
421429
422- runners ["scheduler" ] = sd3_schedulers .SharkSchedulerWrapper (
423- self .devices ["mmdit" ]["driver" ],
424- vmfbs ["scheduler" ],
425- )
430+ if not self .cpu_scheduling :
431+ runners ["scheduler" ] = sd3_schedulers .SharkSchedulerWrapper (
432+ self .devices ["mmdit" ]["driver" ],
433+ vmfbs ["scheduler" ],
434+ )
435+ else :
436+ print ("Using torch CPU scheduler." )
437+ runners ["scheduler" ] = FlowMatchEulerDiscreteScheduler .from_pretrained (
438+ self .hf_model_name , subfolder = "scheduler"
439+ )
426440
427441 sched_loaded = time .time ()
428442 print ("\n [LOG] Scheduler loaded in " , sched_loaded - unet_loaded , "sec" )
@@ -495,11 +509,12 @@ def generate_images(
495509 )
496510 )
497511
498- guidance_scale = ireert .asdevicearray (
499- self .runners ["pipe" ].config .device ,
500- np .asarray ([guidance_scale ]),
501- dtype = iree_dtype ,
502- )
512+ if not self .cpu_scheduling :
513+ guidance_scale = ireert .asdevicearray (
514+ self .runners ["pipe" ].config .device ,
515+ np .asarray ([guidance_scale ]),
516+ dtype = iree_dtype ,
517+ )
503518
504519 tokenize_start = time .time ()
505520 text_input_ids_dict = self .tokenizer .tokenize_with_weights (prompt )
@@ -533,12 +548,23 @@ def generate_images(
533548 "clip"
534549 ].ctx .modules .compiled_text_encoder ["encode_tokens" ](* text_encoders_inputs )
535550 encode_prompts_end = time .time ()
551+ if self .cpu_scheduling :
552+ timesteps , num_inference_steps = sd3_schedulers .retrieve_timesteps (
553+ self .runners ["scheduler" ],
554+ num_inference_steps = self .num_inference_steps ,
555+ timesteps = None ,
556+ )
557+ steps = num_inference_steps
558+
536559
537560 for i in range (batch_count ):
538561 unet_start = time .time ()
539- sample , steps , timesteps = self .runners ["scheduler" ].initialize (samples [i ])
562+ if not self .cpu_scheduling :
563+ latents , steps , timesteps = self .runners ["scheduler" ].initialize (samples [i ])
564+ else :
565+ latents = torch .tensor (samples [i ].to_host (), dtype = self .torch_dtype )
540566 iree_inputs = [
541- sample ,
567+ latents ,
542568 ireert .asdevicearray (
543569 self .runners ["pipe" ].config .device , prompt_embeds , dtype = iree_dtype
544570 ),
@@ -553,40 +579,71 @@ def generate_images(
553579 # print(f"step {s}")
554580 if self .cpu_scheduling :
555581 step_index = s
582+ t = timesteps [s ]
583+ if self .do_classifier_free_guidance :
584+ latent_model_input = torch .cat ([latents ] * 2 )
585+ timestep = ireert .asdevicearray (
586+ self .runners ["pipe" ].config .device ,
587+ t .expand (latent_model_input .shape [0 ]),
588+ dtype = iree_dtype ,
589+ )
590+ latent_model_input = ireert .asdevicearray (
591+ self .runners ["pipe" ].config .device ,
592+ latent_model_input ,
593+ dtype = iree_dtype ,
594+ )
556595 else :
557596 step_index = ireert .asdevicearray (
558597 self .runners ["scheduler" ].runner .config .device ,
559598 torch .tensor ([s ]),
560599 "int64" ,
561600 )
562- latents , t = self .runners ["scheduler" ].prep (
563- sample ,
564- step_index ,
565- timesteps ,
566- )
601+ latent_model_input , timestep = self .runners ["scheduler" ].prep (
602+ latents ,
603+ step_index ,
604+ timesteps ,
605+ )
606+ t = ireert .asdevicearray (
607+ self .runners ["scheduler" ].runner .config .device ,
608+ timestep .to_host ()[0 ]
609+ )
567610 noise_pred = self .runners ["pipe" ].ctx .modules .compiled_mmdit [
568611 "run_forward"
569612 ](
570- latents ,
613+ latent_model_input ,
571614 iree_inputs [1 ],
572615 iree_inputs [2 ],
573- t ,
616+ timestep ,
574617 )
575- sample = self .runners ["scheduler" ].step (
576- noise_pred ,
577- t ,
578- sample ,
579- guidance_scale ,
580- step_index ,
581- )
582- if isinstance (sample , torch .Tensor ):
618+ if not self .cpu_scheduling :
619+ latents = self .runners ["scheduler" ].step (
620+ noise_pred ,
621+ t ,
622+ latents ,
623+ guidance_scale ,
624+ step_index ,
625+ )
626+ else :
627+ noise_pred = torch .tensor (noise_pred .to_host (), dtype = self .torch_dtype )
628+ if self .do_classifier_free_guidance :
629+ noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
630+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
631+ latents = self .runners ["scheduler" ].step (
632+ noise_pred ,
633+ t ,
634+ latents ,
635+ return_dict = False ,
636+ )[0 ]
637+
638+ if isinstance (latents , torch .Tensor ):
639+ latents = latents .type (self .vae_dtype )
583640 latents = ireert .asdevicearray (
584641 self .runners ["vae" ].config .device ,
585- sample ,
586- dtype = self .vae_dtype ,
642+ latents ,
587643 )
588644 else :
589- latents = sample .astype ("float32" )
645+ vae_numpy_dtype = np .float32 if self .vae_precision == "fp32" else np .float16
646+ latents = latents .astype (vae_numpy_dtype )
590647
591648 vae_start = time .time ()
592649 vae_out = self .runners ["vae" ].ctx .modules .compiled_vae ["decode" ](latents )
@@ -634,7 +691,7 @@ def generate_images(
634691 out_image = Image .fromarray (image )
635692 images .extend ([[out_image ]])
636693 if return_imgs :
637- return images
694+ return images [ 0 ]
638695 for idx_batch , image_batch in enumerate (images ):
639696 for idx , image in enumerate (image_batch ):
640697 img_path = (
@@ -767,7 +824,6 @@ def run_diffusers_cpu(
767824 args .hf_model_name ,
768825 args .height ,
769826 args .width ,
770- args .shift ,
771827 args .precision ,
772828 args .max_length ,
773829 args .batch_size ,
@@ -779,16 +835,15 @@ def run_diffusers_cpu(
779835 args .decomp_attn ,
780836 args .pipeline_dir ,
781837 args .external_weights_dir ,
782- args .external_weights ,
783- args .vae_decomp_attn ,
784- custom_vae = None ,
838+ external_weights = args .external_weights ,
839+ vae_decomp_attn = args .vae_decomp_attn ,
785840 cpu_scheduling = args .cpu_scheduling ,
786841 vae_precision = args .vae_precision ,
787842 )
788- vmfbs , weights = sd3_pipe .check_prepared (mlirs , vmfbs , weights )
789843 if args .cpu_scheduling :
790844 vmfbs .pop ("scheduler" )
791845 weights .pop ("scheduler" )
846+ vmfbs , weights = sd3_pipe .check_prepared (mlirs , vmfbs , weights )
792847 if args .npu_delegate_path :
793848 extra_device_args = {"npu_delegate_path" : args .npu_delegate_path }
794849 else :
0 commit comments