Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from .cache_methods.nodes_cache import NODE_CLASS_MAPPINGS as NODE_CACHE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as NODE_CACHE_DISPLAY_NAME_MAPPINGS
from .nodes_deprecated import NODE_CLASS_MAPPINGS as DEPRECATED_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as DEPRECATED_NODE_DISPLAY_NAME_MAPPINGS



try:
from .qwen.qwen import NODE_CLASS_MAPPINGS as QWEN_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as QWEN_NODE_DISPLAY_NAME_MAPPINGS
except Exception as e:
Expand All @@ -23,18 +25,35 @@

try:
from .fantasyportrait.nodes import NODE_CLASS_MAPPINGS as FANTASYPORTRAIT_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as FANTASYPORTRAIT_NODE_DISPLAY_NAME_MAPPINGS
from .fantasyportrait.nodes_multi import NODE_CLASS_MAPPINGS as FANTASYPORTRAIT_MULTI_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as FANTASYPORTRAIT_MULTI_NODE_DISPLAY_NAME_MAPPINGS
except Exception as e:
print(f"FantasyPortrait nodes not available due to error in importing them: {e}")
FANTASYPORTRAIT_NODE_CLASS_MAPPINGS = {}
FANTASYPORTRAIT_NODE_DISPLAY_NAME_MAPPINGS = {}

try:
from .fantasyportrait.fp_mask_nodes import (
NODE_CLASS_MAPPINGS as FP_MASK_NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS as FP_MASK_NODE_DISPLAY_NAME_MAPPINGS,
)
except Exception as e:
print(f"FP mask nodes not available due to error: {e}")
FP_MASK_NODE_CLASS_MAPPINGS = {}
FP_MASK_NODE_DISPLAY_NAME_MAPPINGS = {}


try:
from .unianimate.nodes import NODE_CLASS_MAPPINGS as UNIANIMATE_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as UNIANIMATE_NODE_DISPLAY_NAME_MAPPINGS
except Exception as e:
print(f"UniAnimate nodes not available due to error in importing them: {e}")
UNIANIMATE_NODE_CLASS_MAPPINGS = {}
UNIANIMATE_NODE_DISPLAY_NAME_MAPPINGS = {}

NODE_CLASS_MAPPINGS.update(FANTASYPORTRAIT_MULTI_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(FANTASYPORTRAIT_MULTI_NODE_DISPLAY_NAME_MAPPINGS)
NODE_CLASS_MAPPINGS.update(FP_MASK_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(FP_MASK_NODE_DISPLAY_NAME_MAPPINGS)

NODE_CLASS_MAPPINGS.update(RECAM_MASTER_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UNIANIMATE_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(SKYREELS_NODE_CLASS_MAPPINGS)
Expand Down

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions fantasyportrait/auto_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os, logging, torch
from .fp_mask_ctx import enable_mask_globally
from .fp_ids_resolver import (
make_wan_ids_resolver_dynamic,
make_full_frame_masks,
)

log = logging.getLogger("wanvw.fp")

def _left_right_masks(batch:int, H:int, W:int):
left = torch.zeros((batch,1,H,W)); left[:,:,:, :W//2] = 1
right = torch.zeros((batch,1,H,W)); right[:,:,:, W//2:] = 1
return [left, right]

if os.getenv("WANVW_FP_MASK","0") == "1":
try:
B = int(os.getenv("WANVW_BATCH","1"))
H = int(os.getenv("WANVW_IMAGE_H","720"))
W = int(os.getenv("WANVW_IMAGE_W","720"))
N = int(os.getenv("WANVW_NUM_PEOPLE","1"))
driver_pp = int(os.getenv("WANVW_DRIVER_PER_PERSON","512"))
mode = os.getenv("WANVW_MASK_MODE","strict") # "strict" | "soft"
layout = os.getenv("WANVW_MASK_LAYOUT","full") # "full" | "leftright"
latent_down = int(os.getenv("WANVW_LATENT_DOWN","16")) # typical = 16

# Build masks (quick smoke test or full-frame per-person)
if layout == "leftright" and N == 2:
masks = _left_right_masks(B, H, W)
else:
masks = make_full_frame_masks(B, H, W, num_people=N)

driver_lengths = [driver_pp] * N

# Resolver: prefer dynamic rectangular mapping using H×W
ids_resolver = make_wan_ids_resolver_dynamic(
masks, driver_lengths,
image_h=H, image_w=W, latent_down=latent_down
)

enable_mask_globally(
ids_resolver, mode=mode,
soft_bias=float(os.getenv("WANVW_SOFT_BIAS","4.0")),
allow_global=True, debug=int(os.getenv("WANVW_MASK_DEBUG","2"))
)
log.info(f"[AUTO_MASK] Enabled: N={N}, driver_per_person={driver_pp}, layout={layout}, mode={mode}, out={W}x{H}, down={latent_down}")
except Exception as e:
log.exception(f"[AUTO_MASK] failed to enable global masking: {e}")
6 changes: 6 additions & 0 deletions fantasyportrait/auto_probe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# File: fantasyportrait/auto_probe.py
# Auto-enable global probe when WANVW_FP_PROBE=1 is set.
import os
if os.getenv("WANVW_FP_PROBE", "0") == "1":
from .fp_mask_ctx import enable_probe_globally
enable_probe_globally(True)
166 changes: 166 additions & 0 deletions fantasyportrait/fp_ids_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from __future__ import annotations
import os
from typing import List, Tuple, Optional, Dict

import torch
import torch.nn.functional as F

from .fp_mask_utils import (
downsample_masks_to_tokens,
driver_token_ids_from_lengths,
)

try:
import torch._dynamo as _dynamo
_dynamo_disable = _dynamo.disable
except Exception:
def _dynamo_disable(fn): return fn # no-op fallback


def _env_int(name: str, default: int) -> int:
try:
return int(os.environ.get(name, default))
except Exception:
return int(default)


@_dynamo_disable
def make_full_frame_masks(
layout: str,
B: int,
H: int,
W: int,
*,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
"""Return list of [B,1,H,W] masks for quick tests."""
device = device or torch.device("cpu")
ly = layout.lower()
if ly in ("leftright", "left-right", "lr"):
left = torch.zeros((B, 1, H, W), device=device)
right = torch.zeros((B, 1, H, W), device=device)
left[:, :, :, : W // 2] = 1.0
right[:, :, :, W // 2 :] = 1.0
return [left, right]
if ly in ("topbottom", "top-bottom", "tb"):
top = torch.zeros((B, 1, H, W), device=device)
bot = torch.zeros((B, 1, H, W), device=device)
top[:, :, : H // 2, :] = 1.0
bot[:, :, H // 2 :, :] = 1.0
return [top, bot]
raise ValueError(f"Unknown layout '{layout}'")


@_dynamo_disable
def make_wan720_ids_resolver(
masks_bchw: List[torch.Tensor], # each [B or 1, 1, H, W]
driver_lengths: List[int], # e.g., [97, 97]
*,
global_prefix: int = 0,
global_suffix: int = 0,
latent_down: Optional[int] = None,
image_h: Optional[int] = None,
image_w: Optional[int] = None,
):
"""
Returns ids_resolver(qshape, kshape, info) -> (ids_q[B,Tq], ids_k[B,Tk]).
Handles dynamic batch B (Wan’s frame chunking) and caches per-(B, Tq/Tk).
Only activates when Tq == (H/down)*(W/down) and Tk == prefix+sum(driver_lengths)+suffix.
"""
H = image_h or _env_int("WANVW_IMAGE_H", 720)
W = image_w or _env_int("WANVW_IMAGE_W", 1280)
down = latent_down or _env_int("WANVW_LATENT_DOWN", 16)
Ht, Wt = H // down, W // down
Tpf = Ht * Wt # tokens per frame

ids_q_cache: Dict[Tuple[int, int], torch.Tensor] = {}
ids_k_cache: Dict[Tuple[int, int], torch.Tensor] = {}

base_masks = [m.clone() for m in masks_bchw]

def _replicate_masks_for_B(B: int, device: torch.device) -> List[torch.Tensor]:
out: List[torch.Tensor] = []
for m in base_masks:
t = m
if t.dim() == 3:
t = t.unsqueeze(1) # [B,1,H,W]
if t.shape[0] != B:
if t.shape[0] == 1:
t = t.repeat(B, 1, 1, 1) # no shared storage
else:
t = t[:B]
out.append(t.to(device=device, dtype=torch.float32))
return out

@_dynamo_disable
def _resolve(qshape: Tuple[int, int, int, int], kshape: Tuple[int, int, int, int], info: dict):
B, _, Tq, _ = qshape
Tk = kshape[2]

expected_Tk = int(global_prefix) + sum(int(x) for x in driver_lengths) + int(global_suffix)
if Tq != Tpf or Tk != expected_Tk:
return (None, None)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

key_q = (int(B), int(Tq))
ids_q = ids_q_cache.get(key_q)
if ids_q is None:
ms = _replicate_masks_for_B(int(B), device)
ids_q = downsample_masks_to_tokens(ms, Ht, Wt) # [B, Tpf]
ids_q_cache[key_q] = ids_q
print(f"[IDS] ids_q built: B={B} Ht={Ht} Wt={Wt} -> {tuple(ids_q.shape)}")

key_k = (int(B), int(Tk))
ids_k = ids_k_cache.get(key_k)
if ids_k is None:
ids_k = driver_token_ids_from_lengths(
driver_lengths,
global_prefix=global_prefix,
global_suffix=global_suffix,
batch=int(B),
device=device,
)
ids_k_cache[key_k] = ids_k
print(f"[IDS] ids_k built: B={B} Tk={Tk} -> {tuple(ids_k.shape)}")

return (ids_q, ids_k)

return _resolve


# --- Back-compat alias expected by some FantasyPortrait nodes ---
@_dynamo_disable
def make_wan_ids_resolver_dynamic(
masks_bchw: List[torch.Tensor],
driver_lengths: List[int],
**kwargs,
):
"""
Backward-compatible alias. Calls make_wan720_ids_resolver(...) internally.
"""
return make_wan720_ids_resolver(masks_bchw, driver_lengths, **kwargs)


@_dynamo_disable
def make_env_rect_ids_resolver():
H = _env_int("WANVW_IMAGE_H", 720)
W = _env_int("WANVW_IMAGE_W", 1280)
down = _env_int("WANVW_LATENT_DOWN", 16)
layout = os.environ.get("WANVW_MASK_LAYOUT", "leftright")
n_people = _env_int("WANVW_NUM_PEOPLE", 2)
driver_per = _env_int("WANVW_DRIVER_PER_PERSON", 97)
# dummy B for mask creation; real B comes from qshape during resolve
masks = make_full_frame_masks(layout, 1, H, W)
driver_lengths = [driver_per] * n_people
return make_wan720_ids_resolver(
masks, driver_lengths, latent_down=down, image_h=H, image_w=W
)


__all__ = [
"make_full_frame_masks",
"make_wan720_ids_resolver",
"make_wan_ids_resolver_dynamic", # <— the alias FantasyPortrait nodes import
"make_env_rect_ids_resolver",
]
18 changes: 18 additions & 0 deletions fantasyportrait/fp_integration_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations
import logging
from .fp_mask_ctx import fp_mask_session, enable_probe_logging

log = logging.getLogger("wanvw.fp")

def run_with_mask(callable_fn, ids_resolver, *args, **kwargs):
"""
Wrap ANY existing sampling call and enable masked attention
(plus probe logs for visibility).
"""
enable_probe_logging(logging.INFO)
log.info("[FP-MASK] starting masked session (probe=True, mask=True, strict)")
with fp_mask_session(
probe=True, mask=True, mode="strict", soft_bias=4.0,
allow_global=True, debug=2, ids_resolver=ids_resolver
):
return callable_fn(*args, **kwargs)
15 changes: 15 additions & 0 deletions fantasyportrait/fp_integration_probe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations
import logging
from .fp_mask_ctx import fp_mask_session, enable_probe_logging

log = logging.getLogger("wanvw.fp")

def run_with_probe(callable_fn, *args, **kwargs):
"""
Wrap ANY existing sampling/inference call.
Probe-only (no behavior change): prints [PROBE:sdpa] lines.
"""
enable_probe_logging(logging.INFO)
log.info("[FP-PROBE] starting probe session (mask=False, debug=1)")
with fp_mask_session(probe=True, mask=False, debug=1):
return callable_fn(*args, **kwargs)
Loading