@@ -41,6 +41,10 @@ def __init__(self, device: str = "CPU", dtype: torch.dtype = torch.float16):
4141 self .runtime = Runtime .get ()
4242 self ._initialized = False
4343
44+ # Cumulative timing metrics, in seconds (reset at the start of each generate_image call)
45+ self .models_load_time = 0.0
46+ self .exec_time = 0.0
47+
4448 def load_tokenizer (self , vocab_path : str ):
4549 """Load CLIP tokenizer"""
4650 try :
@@ -89,8 +93,6 @@ def encode_prompt(self, prompt: str):
8993 return None
9094
9195 try :
92- start_time = time .time ()
93-
9496 inputs = self .tokenizer (
9597 prompt ,
9698 padding = "max_length" ,
@@ -99,10 +101,17 @@ def encode_prompt(self, prompt: str):
99101 return_tensors = "pt" ,
100102 )
101103
104+ load_start = time .time ()
102105 text_encoder_method = self .models ["text_encoder" ].load_method ("forward" )
106+ load_time = time .time () - load_start
107+ self .models_load_time += load_time
108+
109+ exec_start = time .time ()
103110 embeddings = text_encoder_method .execute ([inputs .input_ids ])[0 ]
111+ exec_time = time .time () - exec_start
112+ self .exec_time += exec_time
104113
105- logger .info (f"Text encoded ( { time . time () - start_time :.3f} s) " )
114+ logger .info (f"Text encoder - Load: { load_time :.3f } s, Execute: { exec_time :.3f} s" )
106115 return embeddings
107116 except Exception as e :
108117 logger .error (f"Failed to encode prompt: { e } " )
@@ -136,7 +145,11 @@ def denoise_latents(
136145 self .scheduler .set_timesteps (num_steps )
137146
138147 # Get UNet method
148+ load_start = time .time ()
139149 unet_method = self .models ["unet" ].load_method ("forward" )
150+ load_time = time .time () - load_start
151+ self .models_load_time += load_time
152+ logger .info (f"UNet - Load: { load_time :.3f} s" )
140153
141154 # Denoising loop
142155 logger .info (f"Running LCM denoising with { num_steps } steps..." )
@@ -164,10 +177,12 @@ def denoise_latents(
164177 f" Step { step + 1 } /{ num_steps } completed ({ time .time () - step_start :.3f} s)"
165178 )
166179
167- denoise_elapsed = time .time () - denoise_start
180+ exec_time = time .time () - denoise_start
181+ self .exec_time += exec_time
168182 logger .info (
169- f"Denoising completed ( { denoise_elapsed :.3f} s, avg { denoise_elapsed / num_steps :.3f} s/step) "
183+ f"UNet - Execute: { exec_time :.3f} s, avg { exec_time / num_steps :.3f} s/step"
170184 )
185+
171186 return latents
172187 except Exception as e :
173188 logger .error (f"Failed during denoising: { e } " )
@@ -180,17 +195,30 @@ def decode_image(self, latents: torch.Tensor):
180195 return None
181196
182197 try :
183- start_time = time .time ()
184-
198+ load_start = time .time ()
185199 vae_method = self .models ["vae_decoder" ].load_method ("forward" )
200+ load_time = time .time () - load_start
201+ self .models_load_time += load_time
202+
203+ exec_start = time .time ()
186204 decoded_image = vae_method .execute ([latents ])[0 ]
205+ exec_time = time .time () - exec_start
206+ self .exec_time += exec_time
187207
188208 # Convert from (1, 3, 512, 512) CHW to (512, 512, 3) HWC
209+ conversion_start = time .time ()
189210 decoded_image = decoded_image .squeeze (0 ).permute (1 , 2 , 0 )
190211 decoded_image = (decoded_image * 255 ).clamp (0 , 255 ).to (torch .uint8 )
191-
192212 image = Image .fromarray (decoded_image .numpy ())
193- logger .info (f"Image decoded ({ time .time () - start_time :.3f} s)" )
213+ postprocess_time = time .time () - conversion_start
214+ self .exec_time += postprocess_time
215+
216+ logger .info (
217+ f"VAE decoder - Load: { load_time :.3f} s, "
218+ f"Execute: { exec_time :.3f} s, "
219+ f"Post-process: { postprocess_time :.3f} s"
220+ )
221+
194222 return image
195223 except Exception as e :
196224 logger .error (f"Failed to decode image: { e } " )
@@ -213,6 +241,10 @@ def generate_image(
213241 logger .info (f"Steps: { num_steps } | Guidance: { guidance_scale } | Seed: { seed } " )
214242 logger .info ("=" * 60 )
215243
244+ # Reset cumulative timers so the summary logged below covers only this generation
245+ self .models_load_time = 0.0
246+ self .exec_time = 0.0
247+
216248 total_start = time .time ()
217249
218250 text_embeddings = self .encode_prompt (prompt )
@@ -227,10 +259,13 @@ def generate_image(
227259 if image is None :
228260 return None
229261
262+ total_time = time .time () - total_start
263+
230264 logger .info ("=" * 60 )
231- logger .info (
232- f"✓ Generation completed! Total time: { time .time () - total_start :.3f} s"
233- )
265+ logger .info (f"✓ Generation completed!" )
266+ logger .info (f" Total time: { total_time :.3f} s" )
267+ logger .info (f" Total load time: { self .models_load_time :.3f} s" )
268+ logger .info (f" Total Inference time: { self .exec_time :.3f} s" )
234269 logger .info ("=" * 60 )
235270 return image
236271
0 commit comments