@@ -47,6 +47,9 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) ->
4747 self ._preprocessing_orchestrator : Optional [PreprocessingOrchestrator ] = None
4848
4949 self ._stream = None # set in install
50+ # Per-frame prepared tensor cache to avoid per-step device/dtype alignment and batch repeats
51+ self ._prepared_cache : Optional [Dict [str , Any ]] = None
52+ self ._images_version : int = 0
5053
5154 # ---------- Public API (used by wrapper in a later step) ----------
5255 def install (self , stream ) -> None :
@@ -63,6 +66,8 @@ def install(self, stream) -> None:
6366 setattr (stream , 'controlnets' , self .controlnets )
6467 setattr (stream , 'controlnet_scales' , self .controlnet_scales )
6568 setattr (stream , 'preprocessors' , self .preprocessors )
69+ # Reset caches on install
70+ self ._prepared_cache = None
6671
6772 def add_controlnet (self , cfg : ControlNetConfig , control_image : Optional [Union [str , Any , torch .Tensor ]] = None ) -> None :
6873 model = self ._load_pytorch_controlnet_model (cfg .model_id )
@@ -93,6 +98,18 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
9398 except Exception :
9499 pass
95100
101+ # Align preprocessor target size with stream resolution once (avoid double-resize later)
102+ try :
103+ if hasattr (preproc , 'params' ) and isinstance (getattr (preproc , 'params' ), dict ):
104+ preproc .params ['image_width' ] = int (self ._stream .width )
105+ preproc .params ['image_height' ] = int (self ._stream .height )
106+ if hasattr (preproc , 'image_width' ):
107+ setattr (preproc , 'image_width' , int (self ._stream .width ))
108+ if hasattr (preproc , 'image_height' ):
109+ setattr (preproc , 'image_height' , int (self ._stream .height ))
110+ except Exception :
111+ pass
112+
96113 image_tensor : Optional [torch .Tensor ] = None
97114 if control_image is not None and self ._preprocessing_orchestrator is not None :
98115 image_tensor = self ._prepare_control_image (control_image , preproc )
@@ -103,6 +120,9 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
103120 self .controlnet_scales .append (float (cfg .conditioning_scale ))
104121 self .preprocessors .append (preproc )
105122 self .enabled_list .append (bool (cfg .enabled ))
123+ # Invalidate prepared cache and bump version when graph changes
124+ self ._prepared_cache = None
125+ self ._images_version += 1
106126
107127 def update_control_image_efficient (self , control_image : Union [str , Any , torch .Tensor ], index : Optional [int ] = None ) -> None :
108128 if self ._preprocessing_orchestrator is None :
@@ -134,6 +154,9 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
134154 with self ._collections_lock :
135155 if processed is not None and index < len (self .controlnet_images ):
136156 self .controlnet_images [index ] = processed
157+ # Invalidate prepared cache and bump version for per-frame reuse
158+ self ._prepared_cache = None
159+ self ._images_version += 1
137160 return
138161
139162 # Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync)
@@ -154,6 +177,9 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
154177 for i , img in enumerate (processed_images ):
155178 if img is not None and i < len (self .controlnet_images ):
156179 self .controlnet_images [i ] = img
180+ # Invalidate prepared cache and bump version after bulk update
181+ self ._prepared_cache = None
182+ self ._images_version += 1
157183
158184 def update_controlnet_scale (self , index : int , scale : float ) -> None :
159185 with self ._collections_lock :
@@ -177,6 +203,9 @@ def remove_controlnet(self, index: int) -> None:
177203 del self .preprocessors [index ]
178204 if index < len (self .enabled_list ):
179205 del self .enabled_list [index ]
206+ # Invalidate prepared cache and bump version
207+ self ._prepared_cache = None
208+ self ._images_version += 1
180209
181210 def reorder_controlnets_by_model_ids (self , desired_model_ids : List [str ]) -> None :
182211 """Reorder internal collections to match the desired model_id order.
@@ -295,7 +324,42 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
295324 down_samples_list : List [List [torch .Tensor ]] = []
296325 mid_samples_list : List [torch .Tensor ] = []
297326
298- for cn , img , scale in zip (active_controlnets , active_images , active_scales ):
327+ # Prepare control images once per frame for current device/dtype/batch
328+ try :
329+ main_batch = x_t .shape [0 ]
330+ cache_ok = (
331+ isinstance (self ._prepared_cache , dict )
332+ and self ._prepared_cache .get ('device' ) == x_t .device
333+ and self ._prepared_cache .get ('dtype' ) == x_t .dtype
334+ and self ._prepared_cache .get ('batch' ) == main_batch
335+ and self ._prepared_cache .get ('version' ) == self ._images_version
336+ )
337+ if not cache_ok :
338+ prepared : List [Optional [torch .Tensor ]] = [None ] * len (self .controlnet_images )
339+ for i , base_img in enumerate (self .controlnet_images ):
340+ if base_img is None :
341+ continue
342+ cur = base_img
343+ if cur .dim () == 4 and cur .shape [0 ] != main_batch :
344+ if cur .shape [0 ] == 1 :
345+ cur = cur .repeat (main_batch , 1 , 1 , 1 )
346+ else :
347+ repeat_factor = max (1 , main_batch // cur .shape [0 ])
348+ cur = cur .repeat (repeat_factor , 1 , 1 , 1 )
349+ cur = cur .to (device = x_t .device , dtype = x_t .dtype )
350+ prepared [i ] = cur
351+ self ._prepared_cache = {
352+ 'device' : x_t .device ,
353+ 'dtype' : x_t .dtype ,
354+ 'batch' : main_batch ,
355+ 'version' : self ._images_version ,
356+ 'prepared' : prepared ,
357+ }
358+ prepared_images : List [Optional [torch .Tensor ]] = self ._prepared_cache ['prepared' ] if self ._prepared_cache else [None ] * len (self .controlnet_images )
359+ except Exception :
360+ prepared_images = active_images # Fallback to per-step path if cache prep fails
361+
362+ for cn , img , scale , idx_i in zip (active_controlnets , active_images , active_scales , active_indices ):
299363 # Swap to TRT engine if compiled and available for this model_id
300364 try :
301365 model_id = getattr (cn , 'model_id' , None )
@@ -304,22 +368,10 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
304368 # Swapped to TRT engine
305369 except Exception :
306370 pass
307- current_img = img
371+ # Pull from prepared cache if available
372+ current_img = prepared_images [idx_i ] if 'prepared_images' in locals () and prepared_images and idx_i < len (prepared_images ) and prepared_images [idx_i ] is not None else img
308373 if current_img is None :
309374 continue
310- # Ensure control image batch matches latent batch for TRT engines
311- try :
312- main_batch = x_t .shape [0 ]
313- if current_img .dim () == 4 and current_img .shape [0 ] != main_batch :
314- if current_img .shape [0 ] == 1 :
315- current_img = current_img .repeat (main_batch , 1 , 1 , 1 )
316- else :
317- repeat_factor = max (1 , main_batch // current_img .shape [0 ])
318- current_img = current_img .repeat (repeat_factor , 1 , 1 , 1 )
319- # Align device/dtype with latent for engine inputs
320- current_img = current_img .to (device = x_t .device , dtype = x_t .dtype )
321- except Exception :
322- pass
323375 kwargs = base_kwargs .copy ()
324376 kwargs ['controlnet_cond' ] = current_img
325377 kwargs ['conditioning_scale' ] = float (scale )
0 commit comments