14 changes: 12 additions & 2 deletions context_windows/context.py
@@ -186,8 +186,18 @@ def get_total_steps(

def create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=False, window_type="linear"):
window_mask = torch.ones_like(noise_pred_context)

if window_type == "pyramid":

if window_type == "flashvsr":
# Special mode for FlashVSR: use overlap for context but don't blend
# First chunk: keep all frames (weight=1)
# Later chunks: mask out overlap region (weight=0), keep only new frames (weight=1)
if min(c) > 0: # Not the first chunk
# Set overlap region to 0 so these frames don't contribute to final result
window_mask[:, :context_overlap] = 0
# All other frames get weight 1 (no blending/ramping)
return window_mask

elif window_type == "pyramid":
# Create pyramid weights that peak in the middle
length = noise_pred_context.shape[1]
if length % 2 == 0:
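For reference, a minimal sketch of what the new "flashvsr" branch produces for a later (non-first) window. The tensor sizes and chunk indices below are invented for illustration only: the overlap frames get weight 0 and every other frame weight 1, so the overlap is decoded for temporal context but never blended into the final result.

import torch

# Hypothetical shapes: batch of 1, 12 latent frames, 4 channels, 8x8 spatial
noise_pred_context = torch.randn(1, 12, 4, 8, 8)
context_overlap = 4
c = [16, 17, 18]  # assumed frame indices of a non-first chunk, so min(c) > 0

window_mask = torch.ones_like(noise_pred_context)
if min(c) > 0:
    # Zero out the overlap region; only the new frames contribute to the output
    window_mask[:, :context_overlap] = 0

print(window_mask[0, :, 0, 0, 0])
# tensor([0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.])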
108 changes: 102 additions & 6 deletions nodes.py
@@ -1654,7 +1654,7 @@ def INPUT_TYPES(s):
"verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}),
},
"optional": {
"fuse_method": (["linear", "pyramid"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle"}),
"fuse_method": (["linear", "pyramid", "flashvsr"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle, flashvsr=no blending (use for FlashVSR upscaling)"}),
"reference_latent": ("LATENT", {"tooltip": "Image to be used as init for I2V models for windows where first frame is not the actual first frame. Mostly useful with MAGREF model"}),
}
}
Expand Down Expand Up @@ -1999,6 +1999,9 @@ def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x,

flashvsr_LQ_images = samples.get("flashvsr_LQ_images", None)

# Get context_options from samples dictionary (passed from WanVideoSampler)
context_options = samples.get("context_options", None)

vae.to(device)

latents = latents.to(device = device, dtype = vae.dtype)
@@ -2010,11 +2013,104 @@ def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x,
if drop_last:
latents = latents[:, :, :-1]

if type(vae).__name__ == "TAEHV":
images = vae.decode_video(latents.permute(0, 2, 1, 3, 4), cond=flashvsr_LQ_images.to(vae.dtype))[0].permute(1, 0, 2, 3)
images = torch.clamp(images, 0.0, 1.0)
images = images.permute(1, 2, 3, 0).cpu().float()
return (images,)
if type(vae).__name__ == "TAEHV":
# FlashVSR decoding with chunking for memory efficiency
# Convert context_frames from pixel to latent space for comparison
latent_context_frames_threshold = max(1, context_options["context_frames"] // 4) if context_options is not None else 999999

if context_options is not None and latents.shape[2] > latent_context_frames_threshold:
# Chunk the decoding with overlap for temporal continuity
context_frames = context_options["context_frames"]
context_overlap = context_options.get("context_overlap", 16)
num_frames = latents.shape[2]

# Work in latent space for chunking
latent_context_frames = max(1, context_frames // 4)
latent_overlap = max(1, context_overlap // 4)
stride = latent_context_frames - latent_overlap

# FlashVSR decoder outputs at PIXEL temporal resolution (4x latent)
# So we need to discard overlap in pixel space
pixel_overlap = context_overlap

log.info(f"Decoding FlashVSR with overlap: {latent_context_frames} latent frames per chunk, {latent_overlap} latent overlap ({pixel_overlap} pixel overlap), {num_frames} total latent frames")

decoded_frames = []

# Process chunks with overlap, but trim the overlap
for chunk_idx, start_idx in enumerate(range(0, num_frames, stride)):
end_idx = min(start_idx + latent_context_frames, num_frames)

# Extract latent chunk
chunk_latents = latents[:, :, start_idx:end_idx]

# Extract corresponding LQ images if they exist
chunk_LQ = None
if flashvsr_LQ_images is not None:
lq_start = start_idx * 4
lq_end = end_idx * 4
chunk_LQ = flashvsr_LQ_images[:, :, lq_start:lq_end].to(vae.dtype)

# Decode this chunk with sequential processing
# Output is at PIXEL temporal resolution (not latent!)
chunk_images = vae.decode_video(
chunk_latents.permute(0, 2, 1, 3, 4),
cond=chunk_LQ,
parallel=False, # Frame-by-frame within chunk
show_progress_bar=True
)[0].permute(1, 0, 2, 3)

# Keep only non-overlapping frames (discard in PIXEL space)
if chunk_idx == 0:
# First chunk: keep all frames
keep_frames = chunk_images
else:
# Calculate actual overlap based on decoder output
# Decoder doesn't always output exactly 4x latent frames
actual_pixel_frames = chunk_images.shape[1]
overlap_ratio = latent_overlap / latent_context_frames
actual_overlap = int(actual_pixel_frames * overlap_ratio)
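# e.g. (assumed numbers): a 20-frame latent chunk with latent_overlap=4 that decodes to
# 80 pixel frames gives overlap_ratio=0.2, so actual_overlap=16 pixel frames are dropped.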

# Discard the overlap frames
if chunk_images.shape[1] > actual_overlap:
keep_frames = chunk_images[:, actual_overlap:]
else:
keep_frames = chunk_images

decoded_frames.append(keep_frames.cpu())

# Log before cleanup
if chunk_idx == 0:
log.info(f"Decoded chunk {start_idx}-{end_idx} latent ({start_idx*4}-{end_idx*4} pixel), decoder output {chunk_images.shape[1]} frames, kept all")
else:
log.info(f"Decoded chunk {start_idx}-{end_idx} latent ({start_idx*4}-{end_idx*4} pixel), decoder output {chunk_images.shape[1]} frames, discarded {actual_overlap}, kept {decoded_frames[-1].shape[1]}")

# Clean up
del chunk_latents, chunk_images, keep_frames
if chunk_LQ is not None:
del chunk_LQ
mm.soft_empty_cache()

# Concatenate all chunks
images = torch.cat(decoded_frames, dim=1)
images = torch.clamp(images, 0.0, 1.0)
images = images.permute(1, 2, 3, 0).float()

del decoded_frames
mm.soft_empty_cache()

return (images,)
else:
# Single-pass decoding for short videos
images = vae.decode_video(
latents.permute(0, 2, 1, 3, 4),
cond=flashvsr_LQ_images.to(vae.dtype) if flashvsr_LQ_images is not None else None,
parallel=True
)[0].permute(1, 0, 2, 3)

images = torch.clamp(images, 0.0, 1.0)
images = images.permute(1, 2, 3, 0).cpu().float()
return (images,)
else:
if end_image is not None:
enable_vae_tiling = False
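The chunked decode above boils down to a simple plan: slide a fixed-size window with overlap, keep every frame of the first chunk, and drop the re-decoded overlap from each later chunk before concatenating. Below is a condensed sketch of that bookkeeping with all sizes assumed; for simplicity it trims in latent space, whereas the real code trims in pixel space via overlap_ratio.

def plan_chunks(num_latent_frames, latent_context_frames, latent_overlap):
    """Return (start, end, frames_to_drop) per chunk; names and sizes are illustrative."""
    stride = latent_context_frames - latent_overlap
    plan = []
    for chunk_idx, start in enumerate(range(0, num_latent_frames, stride)):
        end = min(start + latent_context_frames, num_latent_frames)
        # First chunk keeps everything; later chunks drop the re-decoded overlap.
        drop = 0 if chunk_idx == 0 else latent_overlap
        plan.append((start, end, drop))
    return plan

# 45 latent frames, 20-frame chunks, 4 frames of overlap:
# [(0, 20, 0), (16, 36, 4), (32, 45, 4)] -> 20 + 16 + 9 = 45 frames, each kept exactly once.
print(plan_chunks(45, 20, 4))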
22 changes: 13 additions & 9 deletions nodes_sampler.py
@@ -834,14 +834,12 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle
flashvsr_LQ_images = image_embeds.get("flashvsr_LQ_images", None)
flashvsr_strength = image_embeds.get("flashvsr_strength", 1.0)
if flashvsr_LQ_images is not None:
if flashvsr_LQ_images.shape[0] < num_frames + 4:
missing_frames = num_frames + 4 - flashvsr_LQ_images.shape[0]
last_frame = flashvsr_LQ_images[-1:].repeat(missing_frames, 1, 1, 1)
flashvsr_LQ_images = torch.cat([flashvsr_LQ_images, last_frame], dim=0)
LQ_images = flashvsr_LQ_images[:num_frames+4].unsqueeze(0).movedim(-1, 1).to(dtype) * 2 - 1
LQ_images = flashvsr_LQ_images.unsqueeze(0).movedim(-1, 1).to(device, dtype) * 2 - 1
if context_options is None:
flashvsr_LQ_latent = transformer.LQ_proj_in(LQ_images.to(device))
flashvsr_LQ_latent = transformer.LQ_proj_in(LQ_images)
log.info(f"flashvsr_LQ_latent: {flashvsr_LQ_latent[0].shape}")
if noise.shape[1] != 1:
noise = noise[:, :-1]
seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1])

latent = noise
@@ -1955,7 +1953,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
end = c[-1] * 4 + 1 + 4
center_indices = torch.arange(start, end, 1)
center_indices = torch.clamp(center_indices, min=0, max=LQ_images.shape[2] - 1)
partial_flashvsr_LQ_images = LQ_images[:, :, center_indices].to(device)
partial_flashvsr_LQ_images = LQ_images[:, :, center_indices].to(device, dtype)
partial_flashvsr_LQ_latent = transformer.LQ_proj_in(partial_flashvsr_LQ_images)

if len(timestep.shape) != 1:
@@ -2999,7 +2997,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
torch.cuda.reset_peak_memory_stats(device)
except:
pass
return ({
output_dict = {
"samples": latent.unsqueeze(0).cpu(),
"looped": is_looped,
"end_image": end_image if not fun_or_fl2v_model else None,
@@ -3010,7 +3008,13 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
"cache_states": cache_states,
"latent_ovi_audio": latent_ovi.unsqueeze(0).transpose(1, 2).cpu() if latent_ovi is not None else None,
"flashvsr_LQ_images": LQ_images,
},{
}

# Only pass context_options if it's actually being used (not None)
if context_options is not None:
output_dict["context_options"] = context_options

return (output_dict, {
"samples": callback_latent.unsqueeze(0).cpu() if callback is not None else None,
})

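Taken together, the sampler-side and decoder-side changes form a simple hand-off: "context_options" is attached to the sampler output only when windowing is actually in use, and the FlashVSR decode path falls back to single-pass decoding when the key is absent or the video is short. The outline below follows that control flow; the helper functions and concrete values are assumed rather than taken from the code.

def build_sampler_output(latent, LQ_images, context_options):
    # Mirrors the new return logic: the key is omitted entirely when unused.
    output_dict = {"samples": latent, "flashvsr_LQ_images": LQ_images}
    if context_options is not None:
        output_dict["context_options"] = context_options
    return output_dict

def decode_flashvsr(samples, decode_chunked, decode_single):
    # Mirrors the decoder-side check: chunk only when context_options is present
    # and the latent video is longer than one context window.
    context_options = samples.get("context_options", None)
    num_latent_frames = samples["samples"].shape[2]
    threshold = max(1, context_options["context_frames"] // 4) if context_options is not None else 999999
    if context_options is not None and num_latent_frames > threshold:
        return decode_chunked(samples, context_options)
    return decode_single(samples)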