@@ -20,15 +20,23 @@ class ContextFuseMethod:
2020 FLAT = "flat"
2121 PYRAMID = "pyramid"
2222 RELATIVE = "relative"
23- RANDOM = "random"
24- GAUSS_SIGMA = "gauss-sigma"
25- GAUSS_SIGMA_INV = "gauss-sigma inverse"
26- DELAYED_REVERSE_SAWTOOTH = "delayed reverse sawtooth"
27- PYRAMID_SIGMA = "pyramid-sigma"
28- PYRAMID_SIGMA_INV = "pyramid-sigma inverse"
23+ OVERLAP_LINEAR = "overlap-linear"
2924
30- LIST = [PYRAMID , FLAT , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
31- LIST_STATIC = [PYRAMID , RELATIVE , FLAT , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
25+ RANDOM = "🔬random"
26+ RANDOM_DEPR = "random"
27+ GAUSS_SIGMA = "🔬gauss-sigma"
28+ GAUSS_SIGMA_DEPR = "gauss-sigma"
29+ GAUSS_SIGMA_INV = "🔬gauss-sigma inverse"
30+ GAUSS_SIGMA_INV_DEPR = "gauss-sigma inverse"
31+ DELAYED_REVERSE_SAWTOOTH = "🔬delayed reverse sawtooth"
32+ DELAYED_REVERSE_SAWTOOTH_DEPR = "delayed reverse sawtooth"
33+ PYRAMID_SIGMA = "🔬pyramid-sigma"
34+ PYRAMID_SIGMA_DEPR = "pyramid-sigma"
35+ PYRAMID_SIGMA_INV = "🔬pyramid-sigma inverse"
36+ PYRAMID_SIGMA_INV_DEPR = "pyramid-sigma inverse"
37+
38+ LIST = [PYRAMID , FLAT , OVERLAP_LINEAR , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
39+ LIST_STATIC = [PYRAMID , RELATIVE , FLAT , OVERLAP_LINEAR , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
3240
3341
3442class ContextType :
@@ -354,11 +362,11 @@ def get_context_windows(num_frames: int, opts: Union[ContextOptionsGroup, Contex
354362}
355363
356364
357- def get_context_weights (num_frames : int , fuse_method : str , sigma : Tensor = None ):
358- weights_func = FUSE_MAPPING .get (fuse_method , None )
365+ def get_context_weights (length : int , full_length : int , idxs : list [ int ], ctx_opts : ContextOptions , sigma : Tensor = None ):
366+ weights_func = FUSE_MAPPING .get (ctx_opts . fuse_method , None )
359367 if not weights_func :
360- raise ValueError (f"Unknown fuse_method '{ fuse_method } '." )
361- return weights_func (num_frames , sigma = sigma )
368+ raise ValueError (f"Unknown fuse_method '{ ctx_opts . fuse_method } '." )
369+ return weights_func (length , sigma = sigma , ctx_opts = ctx_opts , full_length = full_length , idxs = idxs )
362370
363371
364372def create_weights_flat (length : int , ** kwargs ) -> list [float ]:
@@ -376,6 +384,20 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
376384 weight_sequence = list (range (1 , max_weight , 1 )) + [max_weight ] + list (range (max_weight - 1 , 0 , - 1 ))
377385 return weight_sequence
378386
387+ def create_weights_overlap_linear (length : int , full_length : int , idxs : list [int ], ctx_opts : ContextOptions , ** kwargs ):
388+ # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
389+ # only expected overlap is given different weights
390+ weights_torch = torch .ones ((length ))
391+ # blend left-side on all except first window
392+ if min (idxs ) > 0 :
393+ ramp_up = torch .linspace (1e-37 , 1 , ctx_opts .context_overlap )
394+ weights_torch [:ctx_opts .context_overlap ] = ramp_up
395+ # blend right-side on all except last window
396+ if max (idxs ) < full_length - 1 :
397+ ramp_down = torch .linspace (1 , 1e-37 , ctx_opts .context_overlap )
398+ weights_torch [- ctx_opts .context_overlap :] = ramp_down
399+ return weights_torch
400+
379401def create_weights_random (length : int , ** kwargs ) -> list [float ]:
380402 if length % 2 == 0 :
381403 max_weight = length // 2
@@ -454,12 +476,20 @@ def create_weights_delayed_reverse_sawtooth(length: int, **kwargs) -> list[float
454476 ContextFuseMethod .FLAT : create_weights_flat ,
455477 ContextFuseMethod .PYRAMID : create_weights_pyramid ,
456478 ContextFuseMethod .RELATIVE : create_weights_pyramid ,
479+ ContextFuseMethod .OVERLAP_LINEAR : create_weights_overlap_linear ,
480+ # experimental
457481 ContextFuseMethod .GAUSS_SIGMA : create_weights_gauss_sigma ,
482+ ContextFuseMethod .GAUSS_SIGMA_DEPR : create_weights_gauss_sigma ,
458483 ContextFuseMethod .GAUSS_SIGMA_INV : create_weights_gauss_sigma_inv ,
484+ ContextFuseMethod .GAUSS_SIGMA_INV_DEPR : create_weights_gauss_sigma_inv ,
459485 ContextFuseMethod .RANDOM : create_weights_random ,
486+ ContextFuseMethod .RANDOM_DEPR : create_weights_random ,
460487 ContextFuseMethod .DELAYED_REVERSE_SAWTOOTH : create_weights_delayed_reverse_sawtooth ,
488+ ContextFuseMethod .DELAYED_REVERSE_SAWTOOTH_DEPR : create_weights_delayed_reverse_sawtooth ,
461489 ContextFuseMethod .PYRAMID_SIGMA : create_weights_pyramid_sigma ,
490+ ContextFuseMethod .PYRAMID_SIGMA_DEPR : create_weights_pyramid_sigma ,
462491 ContextFuseMethod .PYRAMID_SIGMA_INV : create_weights_pyramid_sigma_inv ,
492+ ContextFuseMethod .PYRAMID_SIGMA_INV_DEPR : create_weights_pyramid_sigma_inv ,
463493}
464494
465495
0 commit comments