Skip to content

Commit eafd373

Browse files
authored
Merge pull request #549 from Kosinkadink/develop
Fix memory usage estimation with upcoming ComfyUI PR
2 parents 3795f51 + 6379b5e commit eafd373

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

animatediff/sampling.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,25 @@ def _diffusion_model_groupnormed_wrapper(executor, *args, **kwargs):
257257
##################################################################################
258258

259259

260-
def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionParams):
260+
def create_prepare_sampling_wrapper(model_options: dict, params: InjectionParams):
261+
# keep backwards compatibility
262+
if hasattr(WrappersMP, "PREPARE_SAMPLING"):
263+
comfy.patcher_extension.add_wrapper_with_key(WrappersMP.PREPARE_SAMPLING,
264+
"ADE_prepare_sampling",
265+
_prepare_sampling_wrapper_factory(params),
266+
model_options, is_model_options=True)
267+
268+
269+
def _prepare_sampling_wrapper_factory(params: InjectionParams):
270+
def _prepare_sampling_wrapper(executor, model: ModelPatcher, noise_shape: Tensor, *args, **kwargs):
271+
# TODO: handle various dims instead of defaulting to 0th
272+
# limit noise_shape length to context_length for more accurate vram use estimation
273+
noise_shape = [min(noise_shape[0], params.context_options.context_length)] + list(noise_shape[1:])
274+
return executor(model, noise_shape, *args, **kwargs)
275+
return _prepare_sampling_wrapper
276+
277+
278+
def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionParams, model_options: dict[str]):
261279
params = params.clone()
262280
for context in params.context_options.contexts:
263281
if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT:
@@ -273,6 +291,7 @@ def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionP
273291
enough_latents = False
274292
if params.context_options.context_length and enough_latents:
275293
logger.info(f"Sliding context window sampling activated - latents passed in ({params.full_length}) greater than context_length {params.context_options.context_length}.")
294+
create_prepare_sampling_wrapper(model_options, params)
276295
else:
277296
logger.info(f"Regular sampling activated - latents passed in ({params.full_length}) less or equal to context_length {params.context_options.context_length}.")
278297
params.reset_context()
@@ -440,7 +459,7 @@ def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs):
440459
seed = args[-1]
441460

442461
# apply params to motion model
443-
params = apply_params_to_motion_models(helper, params)
462+
params = apply_params_to_motion_models(helper, params, model_options=guider.model_options)
444463

445464
# store and inject funtions
446465
function_injections.inject_functions(helper, params, guider.model_options)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-animatediff-evolved"
33
description = "Improved AnimateDiff integration for ComfyUI."
4-
version = "1.5.3"
4+
version = "1.5.4"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)