@@ -229,8 +229,6 @@ class T2IAdapterModule(ControlNetModule):
229229
230230 def __init__ (self , device : str = "cuda" , dtype : torch .dtype = torch .float16 ) -> None :
231231 super ().__init__ (device = device , dtype = dtype )
232- # Cached features per adapter (recomputed when image changes)
233- self ._cached_controls : List [Optional [Dict [str , List [Optional [torch .Tensor ]]]]] = []
234232 # Stream reference set in install
235233 self ._stream = None
236234
@@ -291,7 +289,6 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
291289 self .controlnet_scales .append (float (cfg .conditioning_scale ))
292290 self .preprocessors .append (preproc )
293291 self .enabled_list .append (bool (cfg .enabled ))
294- self ._cached_controls .append (None )
295292
296293
297294 # Control image updates, scale/enable toggles, reordering, and removal
@@ -332,7 +329,6 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
332329 active_adapters = [self .controlnets [i ] for i in active_indices ]
333330 active_images = [self .controlnet_images [i ] for i in active_indices ]
334331 active_scales = [self .controlnet_scales [i ] for i in active_indices ]
335- cached_controls = [self ._cached_controls [i ] for i in active_indices ]
336332
337333 # Compute or reuse cached controls per adapter
338334 down_samples_list : List [List [torch .Tensor ]] = []
@@ -343,7 +339,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
343339 # TRT path: per-block aggregated residual: List[Tensor]
344340 block_down_lists : List [List [torch .Tensor ]] = []
345341
346- for idx , (adapter , image_tensor , scale , cached ) in enumerate (zip (active_adapters , active_images , active_scales , cached_controls )):
342+ for idx , (adapter , image_tensor , scale ) in enumerate (zip (active_adapters , active_images , active_scales )):
347343 if image_tensor is None :
348344 continue
349345 # Align with latent batch/device/dtype
@@ -360,25 +356,13 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
360356 except Exception :
361357 pass
362358
363- controls = cached
364- if controls is None :
365- try :
366- adapter .to (device = ctx .x_t_latent .device , dtype = ctx .x_t_latent .dtype )
367- with torch .no_grad ():
368- controls = adapter (current_img )
369-
370- except Exception as e :
371- logger .error (f"build_unet_hook: adapter forward failed: { e } " )
372- continue
373- # Store into cache (under lock)
374- with self ._collections_lock :
375- try :
376- orig_index = active_indices [idx ]
377- self ._cached_controls [orig_index ] = controls
378- except Exception :
379- pass
380- else :
381- pass
359+ try :
360+ adapter .to (device = ctx .x_t_latent .device , dtype = ctx .x_t_latent .dtype )
361+ with torch .no_grad ():
362+ controls = adapter (current_img )
363+ except Exception as e :
364+ logger .error (f"build_unet_hook: adapter forward failed: { e } " )
365+ continue
382366
383367 # Extract down and middle lists
384368 down_list = controls .get ("input" , []) if isinstance (controls , dict ) else []
0 commit comments