2525import copy
2626from datetime import datetime as dt
2727
28- device_list = [
29- "cpu" ,
30- "vulkan" ,
31- "cuda" ,
32- "rocm" ,
33- ]
34-
35- rt_device_list = [
36- "local-task" ,
37- "local-sync" ,
38- "vulkan" ,
39- "cuda" ,
40- "rocm" ,
41- "hip" ,
42- ]
43-
4428empty_pipe_dict = {
45- "vae" : None ,
46- "text_encoders" : None ,
29+ "clip" : None ,
4730 "mmdit" : None ,
4831 "scheduler" : None ,
32+ "vae" : None ,
4933}
5034
5135EMPTY_FLAGS = {
@@ -90,24 +74,40 @@ def __init__(
9074 self .batch_size = batch_size
9175 self .num_inference_steps = num_inference_steps
9276 self .devices = {}
93- if isinstance (self . device , dict ):
77+ if isinstance (device , dict ):
9478 assert isinstance (iree_target_triple , dict ), "Device and target triple must be both dicts or both strings."
9579 self .devices ["clip" ] = {
9680 "device" : device ["clip" ],
81+ "driver" : utils .iree_device_map (device ["clip" ]),
9782 "target" : iree_target_triple ["clip" ]
9883 }
9984 self .devices ["mmdit" ] = {
10085 "device" : device ["mmdit" ],
86+ "driver" : utils .iree_device_map (device ["mmdit" ]),
10187 "target" : iree_target_triple ["mmdit" ]
10288 }
10389 self .devices ["vae" ] = {
10490 "device" : device ["vae" ],
91+ "driver" : utils .iree_device_map (device ["vae" ]),
10592 "target" : iree_target_triple ["vae" ]
10693 }
10794 else :
108- self .devices ["clip" ] = device
109- self .devices ["mmdit" ] = device
110- self .devices ["vae" ] = device
95+ assert isinstance (iree_target_triple , str ), "Device and target triple must be both dicts or both strings."
96+ self .devices ["clip" ] = {
97+ "device" : device ,
98+ "driver" : utils .iree_device_map (device ),
99+ "target" : iree_target_triple
100+ }
101+ self .devices ["mmdit" ] = {
102+ "device" : device ,
103+ "driver" : utils .iree_device_map (device ),
104+ "target" : iree_target_triple
105+ }
106+ self .devices ["vae" ] = {
107+ "device" : device ,
108+ "driver" : utils .iree_device_map (device ),
109+ "target" : iree_target_triple
110+ }
111111 self .iree_target_triple = iree_target_triple
112112 self .ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
113113 self .attn_spec = attn_spec
@@ -176,6 +176,9 @@ def is_prepared(self, vmfbs, weights):
176176 val = None
177177 default_filepath = None
178178 continue
179+ elif key == "clip" :
180+ val = "text_encoders"
181+ default_filepath = os .path .join (self .pipeline_dir , val + ".vmfb" )
179182 else :
180183 val = vmfbs [key ]
181184 default_filepath = os .path .join (self .pipeline_dir , key + ".vmfb" )
@@ -197,7 +200,7 @@ def is_prepared(self, vmfbs, weights):
197200 default_name = os .path .join (
198201 self .external_weights_dir , w_key + "." + self .external_weights
199202 )
200- if w_key == "text_encoders " :
203+ if w_key == "clip " :
201204 default_name = os .path .join (
202205 self .external_weights_dir , f"sd3_clip_fp16.irpa"
203206 )
@@ -287,7 +290,7 @@ def export_submodel(
287290 if weights_only :
288291 input_mlir = {
289292 "vae" : None ,
290- "text_encoders " : None ,
293+ "clip " : None ,
291294 "mmdit" : None ,
292295 "scheduler" : None ,
293296 }
@@ -366,7 +369,7 @@ def export_submodel(
366369 )
367370 del vae_torch
368371 return vae_vmfb , vae_external_weight_path
369- case "text_encoders " :
372+ case "clip " :
370373 _ , text_encoders_vmfb = sd3_text_encoders .export_text_encoders (
371374 self .hf_model_name ,
372375 None ,
@@ -380,7 +383,7 @@ def export_submodel(
380383 self .ireec_flags ["clip" ],
381384 exit_on_vmfb = False ,
382385 pipeline_dir = self .pipeline_dir ,
383- input_mlir = input_mlir ["text_encoders " ],
386+ input_mlir = input_mlir ["clip " ],
384387 attn_spec = self .attn_spec ,
385388 output_batchsize = self .batch_size ,
386389 )
@@ -392,7 +395,6 @@ def load_pipeline(
392395 self ,
393396 vmfbs : dict ,
394397 weights : dict ,
395- rt_device : str | dict [str ],
396398 compiled_pipeline : bool = False ,
397399 split_scheduler : bool = True ,
398400 extra_device_args : dict = {},
@@ -401,35 +403,37 @@ def load_pipeline(
401403 delegate = extra_device_args ["npu_delegate_path" ]
402404 else :
403405 delegate = None
406+
404407 self .runners = {}
405408 runners = {}
406409 load_start = time .time ()
407410 runners ["pipe" ] = vmfbRunner (
408- rt_device ,
411+ self . devices [ "mmdit" ][ "driver" ] ,
409412 vmfbs ["mmdit" ],
410413 weights ["mmdit" ],
411414 )
412415 unet_loaded = time .time ()
413416 print ("\n [LOG] MMDiT loaded in " , unet_loaded - load_start , "sec" )
414417
415418 runners ["scheduler" ] = sd3_schedulers .SharkSchedulerWrapper (
416- rt_device ,
419+ self . devices [ "mmdit" ][ "driver" ] ,
417420 vmfbs ["scheduler" ],
418421 )
419422
420423 sched_loaded = time .time ()
421424 print ("\n [LOG] Scheduler loaded in " , sched_loaded - unet_loaded , "sec" )
422425 runners ["vae" ] = vmfbRunner (
423- rt_device ,
426+ self . devices [ "vae" ][ "driver" ] ,
424427 vmfbs ["vae" ],
425- weights ["vae" ],
428+ weights ["vae" ],
429+ extra_plugin = delegate ,
426430 )
427431 vae_loaded = time .time ()
428432 print ("\n [LOG] VAE Decode loaded in " , vae_loaded - sched_loaded , "sec" )
429- runners ["text_encoders " ] = vmfbRunner (
430- rt_device ,
431- vmfbs ["text_encoders " ],
432- weights ["text_encoders " ],
433+ runners ["clip " ] = vmfbRunner (
434+ self . devices [ "clip" ][ "driver" ] ,
435+ vmfbs ["clip " ],
436+ weights ["clip " ],
433437 )
434438 clip_loaded = time .time ()
435439 print ("\n [LOG] Text Encoders loaded in " , clip_loaded - vae_loaded , "sec" )
@@ -500,29 +504,29 @@ def generate_images(
500504 uncond_input_ids_list = list (uncond_input_ids_dict .values ())
501505 text_encoders_inputs = [
502506 ireert .asdevicearray (
503- self .runners ["text_encoders " ].config .device , text_input_ids_list [0 ]
507+ self .runners ["clip " ].config .device , text_input_ids_list [0 ]
504508 ),
505509 ireert .asdevicearray (
506- self .runners ["text_encoders " ].config .device , text_input_ids_list [1 ]
510+ self .runners ["clip " ].config .device , text_input_ids_list [1 ]
507511 ),
508512 ireert .asdevicearray (
509- self .runners ["text_encoders " ].config .device , text_input_ids_list [2 ]
513+ self .runners ["clip " ].config .device , text_input_ids_list [2 ]
510514 ),
511515 ireert .asdevicearray (
512- self .runners ["text_encoders " ].config .device , uncond_input_ids_list [0 ]
516+ self .runners ["clip " ].config .device , uncond_input_ids_list [0 ]
513517 ),
514518 ireert .asdevicearray (
515- self .runners ["text_encoders " ].config .device , uncond_input_ids_list [1 ]
519+ self .runners ["clip " ].config .device , uncond_input_ids_list [1 ]
516520 ),
517521 ireert .asdevicearray (
518- self .runners ["text_encoders " ].config .device , uncond_input_ids_list [2 ]
522+ self .runners ["clip " ].config .device , uncond_input_ids_list [2 ]
519523 ),
520524 ]
521525
522526 # Tokenize prompt and negative prompt.
523527 encode_prompts_start = time .time ()
524528 prompt_embeds , pooled_prompt_embeds = self .runners [
525- "text_encoders "
529+ "clip "
526530 ].ctx .modules .compiled_text_encoder ["encode_tokens" ](* text_encoders_inputs )
527531 encode_prompts_end = time .time ()
528532
@@ -690,6 +694,34 @@ def run_diffusers_cpu(
690694 mlirs = copy .deepcopy (map )
691695 vmfbs = copy .deepcopy (map )
692696 weights = copy .deepcopy (map )
697+
698+ if any (x for x in [args .clip_device , args .mmdit_device , args .vae_device ]):
699+ assert all (
700+ x for x in [args .clip_device , args .mmdit_device , args .vae_device ]
701+ ), "Please specify device for all submodels or pass --device for all submodels."
702+ assert all (
703+ x for x in [args .clip_target , args .mmdit_target , args .vae_target ]
704+ ), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels."
705+ args .device = "hybrid"
706+ args .iree_target_triple = "_" .join ([args .clip_target , args .mmdit_target , args .vae_target ])
707+ else :
708+ args .clip_device = args .device
709+ args .mmdit_device = args .device
710+ args .vae_device = args .device
711+ args .clip_target = args .iree_target_triple
712+ args .mmdit_target = args .iree_target_triple
713+ args .vae_target = args .iree_target_triple
714+
715+ devices = {
716+ "clip" : args .clip_device ,
717+ "mmdit" : args .mmdit_device ,
718+ "vae" : args .vae_device ,
719+ }
720+ targets = {
721+ "clip" : args .clip_target ,
722+ "mmdit" : args .mmdit_target ,
723+ "vae" : args .vae_target ,
724+ }
693725 ireec_flags = {
694726 "clip" : args .ireec_flags + args .clip_flags ,
695727 "mmdit" : args .ireec_flags + args .unet_flags ,
@@ -705,6 +737,7 @@ def run_diffusers_cpu(
705737 str (args .max_length ),
706738 args .precision ,
707739 args .device ,
740+ args .iree_target_triple ,
708741 ]
709742 if args .decomp_attn :
710743 pipe_id_list .append ("decomp" )
@@ -730,8 +763,8 @@ def run_diffusers_cpu(
730763 args .max_length ,
731764 args .batch_size ,
732765 args .num_inference_steps ,
733- args . device ,
734- args . iree_target_triple ,
766+ devices ,
767+ targets ,
735768 ireec_flags ,
736769 args .attn_spec ,
737770 args .decomp_attn ,
@@ -747,7 +780,7 @@ def run_diffusers_cpu(
747780 vmfbs .pop ("scheduler" )
748781 weights .pop ("scheduler" )
749782 sd3_pipe .load_pipeline (
750- vmfbs , weights , args .rt_device , args . compiled_pipeline , args .split_scheduler
783+ vmfbs , weights , args .compiled_pipeline , args .split_scheduler
751784 )
752785 sd3_pipe .generate_images (
753786 args .prompt ,
0 commit comments