Skip to content

Commit 7bc8b92

Browse files
chaowenguo0hlkyyiyixuxu
authored
add callable object to convert frame into control_frame to reduce cpu memory usage. (#10501)
* Update rerender_a_video.py * Update rerender_a_video.py * Update examples/community/rerender_a_video.py Co-authored-by: hlky <[email protected]> --------- Co-authored-by: hlky <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent f0c6d97 commit 7bc8b92

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/community/rerender_a_video.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def __call__(
632632
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
633633
instead.
634634
frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
635-
control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
635+
control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
636636
strength ('float'): SDEdit strength.
637637
num_inference_steps (`int`, *optional*, defaults to 50):
638638
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -789,7 +789,7 @@ def __call__(
789789
# Currently we only support single control
790790
if isinstance(controlnet, ControlNetModel):
791791
control_image = self.prepare_control_image(
792-
image=control_frames[0],
792+
image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
793793
width=width,
794794
height=height,
795795
batch_size=batch_size,
@@ -924,7 +924,7 @@ def __call__(
924924
for idx in range(1, len(frames)):
925925
image = frames[idx]
926926
prev_image = frames[idx - 1]
927-
control_image = control_frames[idx]
927+
control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
928928
# 5.1 prepare frames
929929
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
930930
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)

0 commit comments

Comments
 (0)