@@ -257,7 +257,25 @@ def _diffusion_model_groupnormed_wrapper(executor, *args, **kwargs):
257257##################################################################################
258258
259259
260- def apply_params_to_motion_models (helper : ModelPatcherHelper , params : InjectionParams ):
260+ def create_prepare_sampling_wrapper (model_options : dict , params : InjectionParams ):
261+ # keep backwards compatibility
262+ if hasattr (WrappersMP , "PREPARE_SAMPLING" ):
263+ comfy .patcher_extension .add_wrapper_with_key (WrappersMP .PREPARE_SAMPLING ,
264+ "ADE_prepare_sampling" ,
265+ _prepare_sampling_wrapper_factory (params ),
266+ model_options , is_model_options = True )
267+
268+
269+ def _prepare_sampling_wrapper_factory (params : InjectionParams ):
270+ def _prepare_sampling_wrapper (executor , model : ModelPatcher , noise_shape : Tensor , * args , ** kwargs ):
271+ # TODO: handle various dims instead of defaulting to 0th
272+ # limit noise_shape length to context_length for more accurate vram use estimation
273+ noise_shape = [min (noise_shape [0 ], params .context_options .context_length )] + list (noise_shape [1 :])
274+ return executor (model , noise_shape , * args , ** kwargs )
275+ return _prepare_sampling_wrapper
276+
277+
278+ def apply_params_to_motion_models (helper : ModelPatcherHelper , params : InjectionParams , model_options : dict [str ]):
261279 params = params .clone ()
262280 for context in params .context_options .contexts :
263281 if context .context_schedule == ContextSchedules .VIEW_AS_CONTEXT :
@@ -273,6 +291,7 @@ def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionP
273291 enough_latents = False
274292 if params .context_options .context_length and enough_latents :
275293 logger .info (f"Sliding context window sampling activated - latents passed in ({ params .full_length } ) greater than context_length { params .context_options .context_length } ." )
294+ create_prepare_sampling_wrapper (model_options , params )
276295 else :
277296 logger .info (f"Regular sampling activated - latents passed in ({ params .full_length } ) less or equal to context_length { params .context_options .context_length } ." )
278297 params .reset_context ()
@@ -440,7 +459,7 @@ def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs):
440459 seed = args [- 1 ]
441460
442461 # apply params to motion model
443- params = apply_params_to_motion_models (helper , params )
462+ params = apply_params_to_motion_models (helper , params , model_options = guider . model_options )
444463
445464 # store and inject funtions
446465 function_injections .inject_functions (helper , params , guider .model_options )
0 commit comments