@@ -41,14 +41,20 @@ def __init__(self, rt_device, vmfb):
         self.runner = vmfbRunner(rt_device, vmfb, None)
 
     def initialize(self, sample):
-        return self.runner.ctx.modules.compiled_scheduler["run_initialize"](sample)
+        sample, time_ids, steps, timesteps = self.runner.ctx.modules.compiled_scheduler["run_initialize"](sample)
+        return sample, time_ids, steps.to_host(), timesteps
 
     def scale_model_input(self, sample, t, timesteps):
         return self.runner.ctx.modules.compiled_scheduler["run_scale"](
             sample, t, timesteps
         )
 
     def step(self, noise_pred, t, sample, guidance_scale, step_index):
+        print(
+            noise_pred.to_host()[:, :, 0, 2],
+            t,
+            sample.to_host()[:, :, 0, 2],
+        )
         return self.runner.ctx.modules.compiled_scheduler["run_step"](
             noise_pred, t, sample, guidance_scale, step_index
         )
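
Note on the `steps.to_host()` call above: results of compiled-module calls come back as `iree.runtime.DeviceArray` objects, and `.to_host()` copies one back to a host-side numpy array so Python-level control flow can use it. A minimal sketch of that round trip; the `local-task` CPU driver below is an assumption for illustration only:

```python
import numpy as np
import iree.runtime as ireert

# Assumed CPU driver purely for illustration; any IREE HAL driver works.
config = ireert.Config("local-task")
steps = ireert.asdevicearray(config.device, np.asarray(50, dtype=np.int64))
host_steps = steps.to_host()  # host numpy value, usable by Python-side code
print(int(host_steps))        # -> 50
```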
@@ -98,7 +104,7 @@ def initialize(self, sample):
             sample.type(self.dtype),
             add_time_ids,
             step_count,
-            timesteps.type(self.dtype),
+            timesteps.type(torch.float32),
         )
 
     def prepare_model_input(self, sample, t, timesteps):
@@ -119,15 +125,11 @@ def step(self, noise_pred, t, sample, guidance_scale, i):
             noise_pred = noise_preds[0] + guidance_scale * (
                 noise_preds[1] - noise_preds[0]
             )
-        if self.model.config.skip_prk_steps == True:
-            sample = self.model.step_plms(noise_pred, t, sample, return_dict=False)[0]
-        else:
-            sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
+        sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
         return sample.type(self.dtype)
 
-
-@torch.no_grad()
 class SharkSchedulerCPUWrapper:
+    @torch.no_grad()
     def __init__(
         self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype
     ):
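
Moving `@torch.no_grad()` from the class statement onto `__init__` matters because `torch.no_grad` decorates a single callable: applied to a class it wraps the class object itself rather than its methods. A short sketch of the working pattern, using a hypothetical class name:

```python
import torch

class NoGradExample:
    @torch.no_grad()  # autograd is disabled for everything run inside __init__
    def __init__(self):
        w = torch.randn(2, 2, requires_grad=True) * 2.0
        # the multiply above was executed under no_grad, so its result
        # is detached from the autograd graph:
        print(w.requires_grad)  # False

NoGradExample()
```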
@@ -137,13 +139,16 @@ def __init__(
         self.dtype = latents_dtype
         self.batch_size = batch_size
         self.module.set_timesteps(num_inference_steps)
+        self.timesteps = self.module.timesteps
         self.torch_dtype = (
             torch.float32 if latents_dtype == "float32" else torch.float16
         )
 
     def initialize(self, sample):
-        height = sample.shape[2]
-        width = sample.shape[3]
+        if isinstance(sample, ireert.DeviceArray):
+            sample = torch.tensor(sample.to_host(), dtype=torch.float32)
+        height = sample.shape[2] * 8
+        width = sample.shape[3] * 8
         original_size = (height, width)
         target_size = (height, width)
         crops_coords_top_left = (0, 0)
@@ -155,10 +160,10 @@ def initialize(self, sample):
             self.torch_dtype
         )
         step_indexes = torch.tensor(len(self.module.timesteps))
-        timesteps = self.module.timesteps
+        timesteps = self.timesteps
         sample = sample * self.module.init_noise_sigma
+        print(sample, add_time_ids, step_indexes, timesteps)
         add_time_ids = ireert.asdevicearray(self.dest, add_time_ids, self.dtype)
-        step_indexes = ireert.asdevicearray(self.dest, step_indexes, "int64")
         return sample, add_time_ids, step_indexes, timesteps
 
     def scale_model_input(self, sample, t, timesteps):
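
The new `* 8` in `initialize` recovers image-space dimensions from the latent tensor: the SDXL VAE downsamples by a factor of 8, so a `(B, 4, H/8, W/8)` latent corresponds to an `(H, W)` image, and the size conditioning in `add_time_ids` expects image-space values. A minimal sketch of that arithmetic, with the tensor shape assumed:

```python
import torch

sample = torch.randn(1, 4, 128, 128)          # latents for a 1024x1024 image
height, width = sample.shape[2] * 8, sample.shape[3] * 8
original_size, target_size = (height, width), (height, width)
crops_coords_top_left = (0, 0)
# SDXL-style micro-conditioning: sizes and crop offsets flattened together
add_time_ids = torch.tensor(
    [list(original_size) + list(crops_coords_top_left) + list(target_size)]
)
print(add_time_ids)  # tensor([[1024, 1024, 0, 0, 1024, 1024]])
```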
@@ -167,24 +172,27 @@ def scale_model_input(self, sample, t, timesteps):
         t = timesteps[t]
         scaled = self.module.scale_model_input(sample, t)
         t = ireert.asdevicearray(self.dest, [t], self.dtype)
+        scaled = ireert.asdevicearray(self.dest, scaled, self.dtype)
         return scaled, t
 
-    def step(self, latents, t, sample, guidance_scale, i):
-        if isinstance(latents, ireert.DeviceArray):
-            latents = torch.tensor(latents.to_host())
+    def step(self, noise_pred, t, latents, guidance_scale, i):
         if isinstance(t, ireert.DeviceArray):
-            t = self.module.timesteps[i]
-        if isinstance(sample, ireert.DeviceArray):
-            sample = torch.tensor(sample.to_host())
+            t = torch.tensor(t.to_host())
+        if isinstance(guidance_scale, ireert.DeviceArray):
+            guidance_scale = torch.tensor(guidance_scale.to_host())
+        noise_pred = torch.tensor(noise_pred.to_host())
         if self.do_classifier_free_guidance:
-            noise_preds = latents.chunk(2)
-            latents = noise_preds[0] + guidance_scale * (
-                noise_preds[1] - noise_preds[0]
-            )
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        print(
+            noise_pred[:, :, 0, 2],
+            t,
+            latents[:, :, 0, 2],
+        )
         return self.module.step(
-            latents,
+            noise_pred,
             t,
-            sample,
+            latents,
             return_dict=False,
         )[0]
 
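For reference, the rewritten `step` implements standard classifier-free guidance: the UNet runs on a doubled batch, `chunk(2)` splits the unconditional and text-conditioned noise predictions, and the guided prediction extrapolates away from the unconditional one. A self-contained sketch of just that computation, with shapes assumed:

```python
import torch

noise_pred = torch.randn(2, 4, 128, 128)  # [unconditional, text-conditioned]
guidance_scale = 7.5
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 128, 128])
```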
@@ -212,19 +220,23 @@ def export_scheduler_model(
     scheduler_module = SchedulingModel(
         hf_model_name, scheduler, height, width, batch_size, num_inference_steps, dtype
     )
-    vmfb_names = [
-        scheduler_id + "Scheduler",
-        f"bs{batch_size}",
-        f"{height}x{width}",
-        precision,
-        str(num_inference_steps),
-        target_triple,
-    ]
-    vmfb_name = "_".join(vmfb_names)
-
     if pipeline_dir:
+        vmfb_names = [
+            scheduler_id + "Scheduler",
+            str(num_inference_steps),
+        ]
+        vmfb_name = "_".join(vmfb_names)
         safe_name = os.path.join(pipeline_dir, vmfb_name)
     else:
+        vmfb_names = [
+            scheduler_id + "Scheduler",
+            f"bs{batch_size}",
+            f"{height}x{width}",
+            precision,
+            str(num_inference_steps),
+            target_triple,
+        ]
+        vmfb_name = "_".join(vmfb_names)
         safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name)
 
     if input_mlir:
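
The naming split above gives pipeline-managed artifacts a short, stable file name while standalone exports keep the fully qualified one. A hypothetical illustration of the two branches; every parameter value below is made up:

```python
scheduler_id, num_inference_steps = "EulerDiscrete", 30
batch_size, height, width = 1, 1024, 1024
precision, target_triple = "fp16", "gfx1100"

short = "_".join([scheduler_id + "Scheduler", str(num_inference_steps)])
full = "_".join([scheduler_id + "Scheduler", f"bs{batch_size}",
                 f"{height}x{width}", precision, str(num_inference_steps),
                 target_triple])
print(short)  # EulerDiscreteScheduler_30
print(full)   # EulerDiscreteScheduler_bs1_1024x1024_fp16_30_gfx1100
```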
@@ -261,7 +273,7 @@ def export_scheduler_model(
     example_prep_args = (
         torch.empty(sample, dtype=dtype),
         torch.empty(1, dtype=torch.int64),
-        torch.empty([19], dtype=dtype),
+        torch.empty([19], dtype=torch.float32),
     )
     timesteps = torch.export.Dim("timesteps")
     prep_dynamic_args = {