Skip to content

Commit 3e6f10b

Browse files
authored
Add SeedVR2 7B support (#352)
1 parent 3bc2214 commit 3e6f10b

File tree

14 files changed

+415
-75
lines changed

14 files changed

+415
-75
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ MFLUX supports the following model families. They have different strengths and w
122122
|[Z-Image](src/mflux/models/z_image/README.md) | Nov 2025 | 6B | Distilled & Base | Yes | Best all-rounder: fast, small, very good quality and realism. |
123123
|[FLUX.2](src/mflux/models/flux2/README.md) | Jan 2026 | 4B & 9B | Distilled & Base | Yes | Fastest + smallest with very good quality and edit capabilities. |
124124
|[FIBO](src/mflux/models/fibo/README.md) | Oct 2025 | 8B | Base | No | Very good JSON-based prompt understanding and editability, medium speed |
125-
|[SeedVR2](src/mflux/models/seedvr2/README.md) | Jun 2025 | 3B || No | Best upscaling model. |
125+
|[SeedVR2](src/mflux/models/seedvr2/README.md) | Jun 2025 | 3B & 7B || No | Best upscaling model. |
126126
|[Qwen Image](src/mflux/models/qwen/README.md) | Aug 2025+ | 20B | Base | No | Large model (slower); strong prompt understanding and world knowledge. Has edit capabilities |
127127
|[Depth Pro](src/mflux/models/depth_pro/README.md) | Oct 2024 ||| No | Very fast and accurate depth estimation model from Apple. |
128128
|[FLUX.1](src/mflux/models/flux/README.md) | Aug 2024 | 12B | Distilled & Base | No (legacy) | Legacy option with decent quality. Has edit capabilities with 'Kontext' model and upscaling support via ControlNet |

src/mflux/models/common/config/model_config.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ def z_image() -> "ModelConfig":
142142
def seedvr2_3b() -> "ModelConfig":
143143
return AVAILABLE_MODELS["seedvr2-3b"]
144144

145+
@staticmethod
@lru_cache
def seedvr2_7b() -> "ModelConfig":
    """Return the ModelConfig registered under the "seedvr2-7b" key.

    Cached with lru_cache so repeated lookups reuse the same instance;
    mirrors the sibling seedvr2_3b() accessor.
    """
    return AVAILABLE_MODELS["seedvr2-7b"]
149+
145150
def x_embedder_input_dim(self) -> int:
146151
if "Fill" in self.model_name:
147152
return 384
@@ -468,4 +473,28 @@ def from_name(
468473
supports_guidance=True,
469474
requires_sigma_shift=None,
470475
),
476+
"seedvr2-7b": ModelConfig(
477+
priority=21,
478+
aliases=["seedvr2-7b", "seedvr2-7B"],
479+
model_name="numz/SeedVR2_comfyUI",
480+
base_model=None,
481+
controlnet_model=None,
482+
custom_transformer_model=None,
483+
num_train_steps=None,
484+
max_sequence_length=None,
485+
supports_guidance=True,
486+
requires_sigma_shift=None,
487+
transformer_overrides={
488+
"vid_dim": 3072,
489+
"heads": 24,
490+
"num_layers": 36,
491+
"mm_layers": 36,
492+
"rope_dim": 64,
493+
"rope_on_text": False,
494+
"rope_freqs_for": "pixel",
495+
"mlp_type": "normal",
496+
"use_output_ada": False,
497+
"last_layer_vid_only": False,
498+
},
499+
),
471500
}

src/mflux/models/seedvr2/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ SeedVR2 is more recent and the preferred method for high-fidelity upscaling and
1111

1212
```sh
1313
mflux-upscale-seedvr2 \
14+
--model seedvr2-7b \
1415
--image-path "input.png" \
1516
--resolution 2160 \
1617
--softness 0.5
@@ -34,7 +35,7 @@ image.save("input_upscaled.png")
3435
```
3536
</details>
3637

37-
This will upscale the image such that the shortest side is 2160 pixels while maintaining the aspect ratio.
38+
This will upscale the image such that the shortest side is 2160 pixels while maintaining the aspect ratio. If `--model` is omitted, MFLUX defaults to `seedvr2-3b`.
3839

3940
Instead of specifying a target resolution, you can also use `--resolution 2x` or `--resolution 3x` to upscale by a factor of 2 or 3 respectively.
4041

@@ -145,4 +146,3 @@ image.save("image_upscaled.png")
145146
</details>
146147

147148
</details>
148-

src/mflux/models/seedvr2/cli/seedvr2_upscale.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,32 @@ def _is_image_file(path: Path) -> bool:
2323
return path.is_file() and path.suffix.lower() in SUPPORTED_IMAGE_SUFFIXES
2424

2525

26+
def _resolve_seedvr2_model(model_arg: str | None, model_path: str | None) -> tuple[ModelConfig, str | None]:
    """Map CLI --model / --model-path values to a SeedVR2 model config.

    Resolution order:
      1. No --model given: default to the 3B config, keeping any --model-path.
      2. A known alias ("seedvr2", "seedvr2-3b", "seedvr2-7b"): that config,
         dropping --model-path so the registered weights are used.
      3. --model-path naming a directory that contains exactly one of the two
         known checkpoint files: the matching config, keeping --model-path.
      4. Fallback: substring sniffing on --model-path (or --model) for a 7B
         marker; otherwise the 3B config.

    Returns the chosen ModelConfig and the model path to pass along (or None).
    """
    if model_arg is None:
        return ModelConfig.seedvr2_3b(), model_path

    alias = model_arg.lower()
    if alias == "seedvr2" or alias == "seedvr2-3b":
        return ModelConfig.seedvr2_3b(), None
    if alias == "seedvr2-7b":
        return ModelConfig.seedvr2_7b(), None

    if model_path is not None:
        candidate_dir = Path(model_path).expanduser()
        if candidate_dir.is_dir():
            found_3b = (candidate_dir / "seedvr2_ema_3b_fp16.safetensors").exists()
            found_7b = (candidate_dir / "seedvr2_ema_7b_fp16.safetensors").exists()
            # Only trust the directory contents when exactly one variant is present.
            if found_3b != found_7b:
                chosen = ModelConfig.seedvr2_7b() if found_7b else ModelConfig.seedvr2_3b()
                return chosen, model_path

    hint = (model_path or model_arg).lower()
    wants_7b = "seedvr2_ema_7b" in hint or "seedvr2-7b" in hint
    return (ModelConfig.seedvr2_7b() if wants_7b else ModelConfig.seedvr2_3b()), model_path
51+
2652
def _expand_image_paths(image_paths: list[Path]) -> list[Path]:
2753
expanded: list[Path] = []
2854
for image_path in image_paths:
@@ -53,11 +79,13 @@ def main():
5379
print("No images to upscale.")
5480
return
5581

82+
model_config, resolved_model_path = _resolve_seedvr2_model(args.model, args.model_path)
83+
5684
# 3. Load the SeedVR2 model
5785
model = SeedVR2(
5886
quantize=args.quantize,
59-
model_path=args.model_path,
60-
model_config=ModelConfig.seedvr2_3b(),
87+
model_path=resolved_model_path,
88+
model_config=model_config,
6189
)
6290

6391
# 4. Register callbacks

src/mflux/models/seedvr2/model/seedvr2_transformer/attention.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def __init__(
1616
qk_bias: bool = False,
1717
qk_norm_eps: float = 1e-5,
1818
rope_dim: int = 128,
19+
rope_freqs_for: str = "lang",
20+
rope_on_text: bool = True,
1921
shared_weights: bool = False,
2022
window: tuple[int, int, int] = (4, 3, 3),
2123
shift: bool = False,
@@ -27,6 +29,7 @@ def __init__(
2729
self.scale = head_dim**-0.5
2830
self.window = window
2931
self.shift = shift
32+
self.rope_on_text = rope_on_text
3033

3134
inner_dim = heads * head_dim
3235

@@ -46,7 +49,7 @@ def __init__(
4649
self.norm_q_txt = RMSNorm(head_dim, eps=qk_norm_eps)
4750
self.norm_k_txt = RMSNorm(head_dim, eps=qk_norm_eps)
4851

49-
self.rope = RoPEModule(dim=rope_dim)
52+
self.rope = RoPEModule(dim=rope_dim, freqs_for=rope_freqs_for)
5053

5154
def __call__(self, vid, txt, vid_shape, txt_shape):
5255
B, L, Bt, Lt = vid.shape[0], vid.shape[1], txt.shape[0], txt.shape[1]
@@ -67,14 +70,21 @@ def __call__(self, vid, txt, vid_shape, txt_shape):
6770
q_txt_rep, k_txt_rep, v_txt_rep = qkv_t_rep[:, 0], qkv_t_rep[:, 1], qkv_t_rep[:, 2]
6871

6972
# 3. Apply RoPE
70-
q_vid, k_vid, q_txt_rep, k_txt_rep = self.rope(
71-
vid_q=q_vid,
72-
vid_k=k_vid,
73-
vid_shape=partitioner.window_shapes,
74-
txt_q=q_txt_rep,
75-
txt_k=k_txt_rep,
76-
txt_shape=mx.repeat(txt_shape, mx.array(counts), axis=0),
77-
)
73+
if self.rope_on_text:
74+
q_vid, k_vid, q_txt_rep, k_txt_rep = self.rope(
75+
vid_q=q_vid,
76+
vid_k=k_vid,
77+
vid_shape=partitioner.window_shapes,
78+
txt_q=q_txt_rep,
79+
txt_k=k_txt_rep,
80+
txt_shape=mx.repeat(txt_shape, mx.array(counts), axis=0),
81+
)
82+
else:
83+
q_vid, k_vid = self.rope(
84+
vid_q=q_vid,
85+
vid_k=k_vid,
86+
vid_shape=partitioner.window_shapes,
87+
)
7888

7989
# 4. Attention
8090
vid_lens = mx.prod(partitioner.window_shapes, axis=1)

src/mflux/models/seedvr2/model/seedvr2_transformer/mm_swiglu.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import mlx.core as mx
22
from mlx import nn
33

4-
from mflux.models.seedvr2.model.seedvr2_transformer.swiglu_mlp import SwiGLUMLP
4+
from mflux.models.seedvr2.model.seedvr2_transformer.swiglu_mlp import GELUMLP, SwiGLUMLP
55

66

77
class MMSwiGLU(nn.Module):
@@ -12,17 +12,25 @@ def __init__(
1212
expand_ratio: int = 4,
1313
shared_weights: bool = False,
1414
is_last_layer: bool = False,
15+
mlp_type: str = "swiglu",
1516
):
1617
super().__init__()
1718
self.shared_weights = shared_weights
1819
self.is_last_layer = is_last_layer
20+
self.mlp_type = mlp_type
21+
22+
mlp_cls = SwiGLUMLP
23+
mlp_kwargs = {"expand_ratio": expand_ratio}
24+
if mlp_type == "normal":
25+
mlp_cls = GELUMLP
26+
mlp_kwargs["bias"] = True
1927

2028
if shared_weights:
21-
self.all = SwiGLUMLP(dim=vid_dim, expand_ratio=expand_ratio)
29+
self.all = mlp_cls(dim=vid_dim, **mlp_kwargs)
2230
else:
23-
self.vid = SwiGLUMLP(dim=vid_dim, expand_ratio=expand_ratio)
31+
self.vid = mlp_cls(dim=vid_dim, **mlp_kwargs)
2432
if not is_last_layer:
25-
self.txt = SwiGLUMLP(dim=txt_dim, expand_ratio=expand_ratio)
33+
self.txt = mlp_cls(dim=txt_dim, **mlp_kwargs)
2634

2735
def __call__(self, vid: mx.array, txt: mx.array) -> tuple[mx.array, mx.array]:
2836
if self.shared_weights:

0 commit comments

Comments
 (0)