Skip to content
This repository was archived by the owner on Dec 26, 2025. It is now read-only.

Commit 93f71cc

Browse files
removet2i cache
1 parent f3514f8 commit 93f71cc

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

src/streamdiffusion/modules/t2i_adapter_module.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,6 @@ class T2IAdapterModule(ControlNetModule):
229229

230230
def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) -> None:
231231
super().__init__(device=device, dtype=dtype)
232-
# Cached features per adapter (recomputed when image changes)
233-
self._cached_controls: List[Optional[Dict[str, List[Optional[torch.Tensor]]]]] = []
234232
# Stream reference set in install
235233
self._stream = None
236234

@@ -291,7 +289,6 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
291289
self.controlnet_scales.append(float(cfg.conditioning_scale))
292290
self.preprocessors.append(preproc)
293291
self.enabled_list.append(bool(cfg.enabled))
294-
self._cached_controls.append(None)
295292

296293

297294
# Control image updates, scale/enable toggles, reordering, and removal
@@ -332,7 +329,6 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
332329
active_adapters = [self.controlnets[i] for i in active_indices]
333330
active_images = [self.controlnet_images[i] for i in active_indices]
334331
active_scales = [self.controlnet_scales[i] for i in active_indices]
335-
cached_controls = [self._cached_controls[i] for i in active_indices]
336332

337333
# Compute or reuse cached controls per adapter
338334
down_samples_list: List[List[torch.Tensor]] = []
@@ -343,7 +339,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
343339
# TRT path: per-block aggregated residual: List[Tensor]
344340
block_down_lists: List[List[torch.Tensor]] = []
345341

346-
for idx, (adapter, image_tensor, scale, cached) in enumerate(zip(active_adapters, active_images, active_scales, cached_controls)):
342+
for idx, (adapter, image_tensor, scale) in enumerate(zip(active_adapters, active_images, active_scales)):
347343
if image_tensor is None:
348344
continue
349345
# Align with latent batch/device/dtype
@@ -360,25 +356,13 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
360356
except Exception:
361357
pass
362358

363-
controls = cached
364-
if controls is None:
365-
try:
366-
adapter.to(device=ctx.x_t_latent.device, dtype=ctx.x_t_latent.dtype)
367-
with torch.no_grad():
368-
controls = adapter(current_img)
369-
370-
except Exception as e:
371-
logger.error(f"build_unet_hook: adapter forward failed: {e}")
372-
continue
373-
# Store into cache (under lock)
374-
with self._collections_lock:
375-
try:
376-
orig_index = active_indices[idx]
377-
self._cached_controls[orig_index] = controls
378-
except Exception:
379-
pass
380-
else:
381-
pass
359+
try:
360+
adapter.to(device=ctx.x_t_latent.device, dtype=ctx.x_t_latent.dtype)
361+
with torch.no_grad():
362+
controls = adapter(current_img)
363+
except Exception as e:
364+
logger.error(f"build_unet_hook: adapter forward failed: {e}")
365+
continue
382366

383367
# Extract down and middle lists
384368
down_list = controls.get("input", []) if isinstance(controls, dict) else []

0 commit comments

Comments
 (0)