@@ -51,26 +51,36 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
5151
5252
class IndexListContextWindow(ContextWindowABC):
    """Context window described by an explicit list of indices along one tensor dimension.

    Attributes:
        index_list: indices (along ``dim``) that make up this window.
        context_length: number of indices in the window.
        dim: default dimension the window indexes into.
        total_frames: full length the window was cut from (0 when unknown).
        center_ratio: midpoint of the window's min/max index divided by
            ``total_frames``; 0.0 when it cannot be computed.
    """

    def __init__(self, index_list: list[int], dim: int = 0, total_frames: int = 0):
        """Store the window indices and precompute the window's relative center.

        Args:
            index_list: indices belonging to the window.
            dim: dimension the indices apply to.
            total_frames: full length along ``dim``; optional for callers that
                do not need ``center_ratio``.
        """
        self.index_list = index_list
        self.context_length = len(index_list)
        self.dim = dim
        self.total_frames = total_frames
        # total_frames defaults to 0 and index_list may be empty; dividing or
        # calling min()/max() unconditionally would raise for such callers.
        if index_list and total_frames > 0:
            self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
        else:
            self.center_ratio = 0.0

    def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=None) -> torch.Tensor:
        """Return the slice of ``full`` covered by this window.

        Args:
            full: tensor to take the window from.
            device: target device passed to ``Tensor.to`` for the result.
            dim: dimension to index; falls back to ``self.dim`` when None.
            retain_index_list: optional positions (along ``dim``) whose values in
                the returned window are overwritten with the values found at the
                same positions of ``full``. None or empty disables this.

        Returns:
            ``full`` itself (no device transfer) when ``dim == 0`` and ``full`` has
            size 1 along it; otherwise the selected window moved to ``device``.
        """
        if dim is None:
            dim = self.dim
        if dim == 0 and full.shape[dim] == 1:
            return full
        leading = [slice(None)] * dim
        # indexing with a list yields a new tensor, so the writes below do not touch `full`
        window = full[tuple(leading + [self.index_list])]
        if retain_index_list:
            # NOTE(review): positions must be valid for both the window and `full` - confirm with callers
            retain_idx = tuple(leading + [retain_index_list])
            window[retain_idx] = full[retain_idx]
        return window.to(device)

    def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
        """Add ``to_add`` in place onto the window's positions of ``full`` and return ``full``.

        Args:
            full: tensor that is modified in place.
            to_add: values added at the window's indices.
            dim: dimension to index; falls back to ``self.dim`` when None.
        """
        if dim is None:
            dim = self.dim
        idx = tuple([slice(None)] * dim + [self.index_list])
        full[idx] += to_add
        return full

    def get_region_index(self, num_regions: int) -> int:
        """Map the window's relative center onto one of ``num_regions`` equal regions.

        Returns:
            ``int(center_ratio * num_regions)`` clamped to ``[0, num_regions - 1]``.
        """
        region_idx = int(self.center_ratio * num_regions)
        return min(max(region_idx, 0), num_regions - 1)
83+
7484
7585class IndexListCallbacks :
7686 EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@@ -94,7 +104,8 @@ class ContextFuseMethod:
94104
# Record with the fields window_idx, sub_conds_out, sub_conds and window.
ContextResults = collections.namedtuple(
    "ContextResults",
    ['window_idx', 'sub_conds_out', 'sub_conds', 'window'],
)
96106class IndexListContextHandler (ContextHandlerABC ):
97- def __init__ (self , context_schedule : ContextSchedule , fuse_method : ContextFuseMethod , context_length : int = 1 , context_overlap : int = 0 , context_stride : int = 1 , closed_loop = False , dim = 0 ):
107+ def __init__ (self , context_schedule : ContextSchedule , fuse_method : ContextFuseMethod , context_length : int = 1 , context_overlap : int = 0 , context_stride : int = 1 ,
108+ closed_loop : bool = False , dim :int = 0 , freenoise : bool = False , cond_retain_index_list : list [int ]= [], split_conds_to_windows : bool = False ):
98109 self .context_schedule = context_schedule
99110 self .fuse_method = fuse_method
100111 self .context_length = context_length
@@ -103,13 +114,18 @@ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMe
103114 self .closed_loop = closed_loop
104115 self .dim = dim
105116 self ._step = 0
117+ self .freenoise = freenoise
118+ self .cond_retain_index_list = [int (x .strip ()) for x in cond_retain_index_list .split ("," )] if cond_retain_index_list else []
119+ self .split_conds_to_windows = split_conds_to_windows
106120
107121 self .callbacks = {}
108122
109123 def should_use_context (self , model : BaseModel , conds : list [list [dict ]], x_in : torch .Tensor , timestep : torch .Tensor , model_options : dict [str ]) -> bool :
110124 # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
111125 if x_in .size (self .dim ) > self .context_length :
112- logging .info (f"Using context windows { self .context_length } for { x_in .size (self .dim )} frames." )
126+ logging .info (f"Using context windows { self .context_length } with overlap { self .context_overlap } for { x_in .size (self .dim )} frames." )
127+ if self .cond_retain_index_list :
128+ logging .info (f"Retaining original cond for indexes: { self .cond_retain_index_list } " )
113129 return True
114130 return False
115131
@@ -123,6 +139,11 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
123139 return None
124140 # reuse or resize cond items to match context requirements
125141 resized_cond = []
142+ # if multiple conds, split based on primary region
143+ if self .split_conds_to_windows and len (cond_in ) > 1 :
144+ region = window .get_region_index (len (cond_in ))
145+ logging .info (f"Splitting conds to windows; using region { region } for window { window [0 ]} -{ window [- 1 ]} with center ratio { window .center_ratio :.3f} " )
146+ cond_in = [cond_in [region ]]
126147 # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
127148 for actual_cond in cond_in :
128149 resized_actual_cond = actual_cond .copy ()
@@ -146,12 +167,19 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
146167 # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
147168 for cond_key , cond_value in new_cond_item .items ():
148169 if isinstance (cond_value , torch .Tensor ):
149- if cond_value .ndim < self .dim and cond_value .size (0 ) == x_in .size (self .dim ):
170+ if (self .dim < cond_value .ndim and cond_value (self .dim ) == x_in .size (self .dim )) or \
171+ (cond_value .ndim < self .dim and cond_value .size (0 ) == x_in .size (self .dim )):
150172 new_cond_item [cond_key ] = window .get_tensor (cond_value , device )
173+ # Handle audio_embed (temporal dim is 1)
174+ elif cond_key == "audio_embed" and hasattr (cond_value , "cond" ) and isinstance (cond_value .cond , torch .Tensor ):
175+ audio_cond = cond_value .cond
176+ if audio_cond .ndim > 1 and audio_cond .size (1 ) == x_in .size (self .dim ):
177+ new_cond_item [cond_key ] = cond_value ._copy_with (window .get_tensor (audio_cond , device , dim = 1 ))
151178 # if has cond that is a Tensor, check if needs to be subset
152179 elif hasattr (cond_value , "cond" ) and isinstance (cond_value .cond , torch .Tensor ):
153- if cond_value .cond .ndim < self .dim and cond_value .cond .size (0 ) == x_in .size (self .dim ):
154- new_cond_item [cond_key ] = cond_value ._copy_with (window .get_tensor (cond_value .cond , device ))
180+ if (self .dim < cond_value .cond .ndim and cond_value .cond .size (self .dim ) == x_in .size (self .dim )) or \
181+ (cond_value .cond .ndim < self .dim and cond_value .cond .size (0 ) == x_in .size (self .dim )):
182+ new_cond_item [cond_key ] = cond_value ._copy_with (window .get_tensor (cond_value .cond , device , retain_index_list = self .cond_retain_index_list ))
155183 elif cond_key == "num_video_frames" : # for SVD
156184 new_cond_item [cond_key ] = cond_value ._copy_with (cond_value .cond )
157185 new_cond_item [cond_key ].cond = window .context_length
@@ -164,7 +192,7 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
164192 return resized_cond
165193
166194 def set_step (self , timestep : torch .Tensor , model_options : dict [str ]):
167- mask = torch .isclose (model_options ["transformer_options" ]["sample_sigmas" ], timestep , rtol = 0.0001 )
195+ mask = torch .isclose (model_options ["transformer_options" ]["sample_sigmas" ], timestep [ 0 ] , rtol = 0.0001 )
168196 matches = torch .nonzero (mask )
169197 if torch .numel (matches ) == 0 :
170198 raise Exception ("No sample_sigmas matched current timestep; something went wrong." )
@@ -173,7 +201,7 @@ def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
173201 def get_context_windows (self , model : BaseModel , x_in : torch .Tensor , model_options : dict [str ]) -> list [IndexListContextWindow ]:
174202 full_length = x_in .size (self .dim ) # TODO: choose dim based on model
175203 context_windows = self .context_schedule .func (full_length , self , model_options )
176- context_windows = [IndexListContextWindow (window , dim = self .dim ) for window in context_windows ]
204+ context_windows = [IndexListContextWindow (window , dim = self .dim , total_frames = full_length ) for window in context_windows ]
177205 return context_windows
178206
179207 def execute (self , calc_cond_batch : Callable , model : BaseModel , conds : list [list [dict ]], x_in : torch .Tensor , timestep : torch .Tensor , model_options : dict [str ]):
@@ -250,8 +278,8 @@ def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_
250278 prev_weight = (bias_total / (bias_total + bias ))
251279 new_weight = (bias / (bias_total + bias ))
252280 # account for dims of tensors
253- idx_window = [slice (None )] * self .dim + [idx ]
254- pos_window = [slice (None )] * self .dim + [pos ]
281+ idx_window = tuple ( [slice (None )] * self .dim + [idx ])
282+ pos_window = tuple ( [slice (None )] * self .dim + [pos ])
255283 # apply new values
256284 conds_final [i ][idx_window ] = conds_final [i ][idx_window ] * prev_weight + sub_conds_out [i ][pos_window ] * new_weight
257285 biases_final [i ][idx ] = bias_total + bias
@@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
287315 )
288316
289317
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
    """Sampler-sample wrapper that optionally applies FreeNoise before delegating to ``executor``.

    Looks up the context handler in ``extra_args["model_options"]``; when its
    ``freenoise`` flag is set, ``noise`` is replaced by the result of
    ``apply_freenoise`` (seeded from ``extra_args["seed"]``). All arguments are
    then forwarded to ``executor`` and its result is returned.

    Raises:
        Exception: if ``model_options`` or its ``context_handler`` entry is missing.
    """
    model_options = extra_args.get("model_options", None)
    if model_options is None:
        raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
    handler: IndexListContextHandler = model_options.get("context_handler", None)
    if handler is None:
        raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
    if handler.freenoise:
        noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
    return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
330+
331+
def create_sampler_sample_wrapper(model: ModelPatcher):
    """Register ``_sampler_sample_wrapper`` on ``model`` as a SAMPLER_SAMPLE wrapper.

    The wrapper is added under the key "ContextWindows_sampler_sample".
    """
    wrapper_type = comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE
    wrapper_key = "ContextWindows_sampler_sample"
    model.add_wrapper_with_key(wrapper_type, wrapper_key, _sampler_sample_wrapper)
338+
339+
290340def match_weights_to_dim (weights : list [float ], x_in : torch .Tensor , dim : int , device = None ) -> torch .Tensor :
291341 total_dims = len (x_in .shape )
292342 weights_tensor = torch .Tensor (weights ).to (device = device )
@@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
538588 for i in range (len (window )):
539589 # 2) add end_delta to each val to slide windows to end
540590 window [i ] = window [i ] + end_delta
591+
592+
593+ # https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
    """Overwrite later parts of ``noise`` with shuffled copies of earlier parts (FreeNoise).

    Walking along ``dim`` in steps of ``context_length - context_overlap``, the
    entries that follow each window of ``context_length`` are replaced by a seeded
    random permutation of the entries at the start of that window.

    Args:
        noise: noise tensor; modified in place.
        dim: dimension of ``noise`` to shuffle along.
        context_length: length of one context window.
        context_overlap: overlap between consecutive windows.
        seed: seed for the CPU generator that draws the permutations.

    Returns:
        The same ``noise`` tensor. It is returned unchanged when the stride
        ``context_length - context_overlap`` is not positive or when ``noise`` is
        no longer than one window along ``dim``.
    """
    logging.info("Context windows: Applying FreeNoise")
    delta = context_length - context_overlap
    if delta <= 0:
        # A zero stride would make range() raise an opaque ValueError and a
        # negative stride yields no windows; treat both as "nothing to do".
        logging.warning("Context windows: FreeNoise skipped because context_overlap is not smaller than context_length")
        return noise

    generator = torch.Generator(device='cpu').manual_seed(seed)
    latent_video_length = noise.shape[dim]

    for start_idx in range(0, latent_video_length - context_length, delta):
        place_idx = start_idx + context_length
        # The range bound keeps place_idx inside the tensor, so at least one
        # entry remains; the last chunk may be shorter than delta.
        actual_delta = min(delta, latent_video_length - place_idx)

        list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx

        source_slice = [slice(None)] * noise.ndim
        source_slice[dim] = list_idx
        target_slice = [slice(None)] * noise.ndim
        target_slice[dim] = slice(place_idx, place_idx + actual_delta)

        # indexing with a tensor copies the source, so overlapping ranges are safe
        noise[tuple(target_slice)] = noise[tuple(source_slice)]

    return noise
0 commit comments