
Commit 878db3a: Implement the Ovis image model. (#11030)
1 parent 30c259c

File tree: 8 files changed, +182 -35 lines

comfy/ldm/chroma/model.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ class ChromaParams:
     out_dim: int
     hidden_dim: int
     n_layers: int
-
+    txt_ids_dims: list
+    vec_in_dim: int



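These two fields mirror the additions to FluxParams later in this commit: txt_ids_dims selects which channels of the text position ids receive a per-token index, and vec_in_dim gates the pooled vector input (see the flux/model.py hunks below). A minimal, hypothetical sketch of a params object carrying the new fields:

```python
from dataclasses import dataclass, field
from typing import Optional

# Hypothetical, trimmed-down stand-in for ChromaParams/FluxParams showing only
# the two fields added by this commit.
@dataclass
class ParamsSketch:
    txt_ids_dims: list = field(default_factory=list)  # txt_ids channels that receive a 0..T-1 ramp
    vec_in_dim: Optional[int] = None                  # None disables the pooled vector input

print(ParamsSketch(txt_ids_dims=[1, 2]))  # ParamsSketch(txt_ids_dims=[1, 2], vec_in_dim=None)
```
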
comfy/ldm/flux/layers.py

Lines changed: 42 additions & 26 deletions
@@ -57,6 +57,35 @@ def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=N
     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))

+class YakMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+    if yak_mlp:
+        return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+    if mlp_silu_act:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+            SiLUActivation(),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+        )
+    else:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )

 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):

@@ -140,7 +169,7 @@ def forward(self, x: Tensor) -> Tensor:


 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)

@@ -156,18 +185,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:

         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

-        if mlp_silu_act:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

@@ -177,18 +195,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:

         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

-        if mlp_silu_act:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

         self.flipped_img_txt = flipped_img_txt


@@ -275,6 +282,7 @@ def __init__(
         modulation=True,
         mlp_silu_act=False,
         bias=True,
+        yak_mlp=False,
         dtype=None,
         device=None,
         operations=None

@@ -288,12 +296,17 @@ def __init__(
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)

         self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        self.yak_mlp = yak_mlp
         if mlp_silu_act:
             self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
             self.mlp_act = SiLUActivation()
         else:
             self.mlp_act = nn.GELU(approximate="tanh")

+        if self.yak_mlp:
+            self.mlp_hidden_dim_first *= 2
+            self.mlp_act = nn.SiLU()
+
         # qkv and mlp_in
         self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out

@@ -325,7 +338,10 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        mlp = self.mlp_act(mlp)
+        if self.yak_mlp:
+            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+        else:
+            mlp = self.mlp_act(mlp)
         output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:

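The new YakMLP is a gated-SiLU (SwiGLU-style) feed-forward block: the down projection is applied to silu(gate(x)) * up(x). SingleStreamBlock does the same gating on a fused projection: when yak_mlp is set, mlp_hidden_dim_first is doubled so linear1 emits both halves, and the forward pass multiplies the SiLU-activated second half by the first half. A self-contained sketch of both paths, using plain nn.Linear instead of ComfyUI's operations wrapper (an assumption made so the snippet runs standalone):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedSiluMLP(nn.Module):
    """Plain-PyTorch sketch of YakMLP (nn.Linear stands in for operations.Linear)."""
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # down(silu(gate(x)) * up(x)), the same computation as YakMLP.forward
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def fused_gated_act(mlp: torch.Tensor, mlp_hidden_dim_first: int) -> torch.Tensor:
    # SingleStreamBlock path: one fused projection of width mlp_hidden_dim_first,
    # split at forward time; SiLU(second half) gates the first half.
    half = mlp_hidden_dim_first // 2
    return F.silu(mlp[..., half:]) * mlp[..., :half]

x = torch.randn(2, 4, 64)
print(GatedSiluMLP(64, 192)(x).shape)                      # torch.Size([2, 4, 64])
print(fused_gated_act(torch.randn(2, 4, 384), 384).shape)  # torch.Size([2, 4, 192])
```

Keeping gate_proj / up_proj / down_proj as separately named modules in YakMLP also matches the checkpoint key layout that model_detection.py now looks for (double_blocks.0.img_mlp.gate_proj.weight).
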
comfy/ldm/flux/model.py

Lines changed: 17 additions & 4 deletions
@@ -15,7 +15,8 @@
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
-    Modulation
+    Modulation,
+    RMSNorm
 )

 @dataclass

@@ -34,11 +35,14 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    txt_ids_dims: list
     global_modulation: bool = False
     mlp_silu_act: bool = False
     ops_bias: bool = True
     default_ref_method: str = "offset"
     ref_index_scale: float = 1.0
+    yak_mlp: bool = False
+    txt_norm: bool = False


 class Flux(nn.Module):

@@ -76,6 +80,11 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
         )
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

+        if params.txt_norm:
+            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+        else:
+            self.txt_norm = None
+
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(

@@ -86,6 +95,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
                     modulation=params.global_modulation is False,
                     mlp_silu_act=params.mlp_silu_act,
                     proj_bias=params.ops_bias,
+                    yak_mlp=params.yak_mlp,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)

@@ -94,7 +104,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,

         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )

@@ -150,6 +160,8 @@ def forward_orig(
                 y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
             vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])

+        if self.txt_norm is not None:
+            txt = self.txt_norm(txt)
         txt = self.txt_in(txt)

         vec_orig = vec

@@ -332,8 +344,9 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None

         txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)

-        if len(self.params.axes_dim) == 4: # Flux 2
-            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+        if len(self.params.txt_ids_dims) > 0:
+            for i in self.params.txt_ids_dims:
+                txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)

         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]

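The hard-coded "Flux 2" check is replaced by the txt_ids_dims list: every channel listed there gets a 0..T-1 ramp over the text tokens and the rest stay zero, so the same code covers regular Flux (empty list), Flux 2 ([3]), and Ovis ([1, 2], set by the detection code below). A small sketch of the resulting txt_ids construction, with device and dtype handling trimmed:

```python
import torch

def make_txt_ids(bs: int, txt_len: int, n_axes: int, txt_ids_dims) -> torch.Tensor:
    # Mirrors the _forward hunk above: zeros everywhere, then a 0..T-1 ramp in
    # each channel named in txt_ids_dims.
    txt_ids = torch.zeros((bs, txt_len, n_axes), dtype=torch.float32)
    for i in txt_ids_dims:
        txt_ids[:, :, i] = torch.linspace(0, txt_len - 1, steps=txt_len, dtype=torch.float32)
    return txt_ids

print(make_txt_ids(1, 5, 4, [3])[0])     # old Flux 2 behaviour: ramp in the last axis only
print(make_txt_ids(1, 5, 3, [1, 2])[0])  # Ovis: ramp in axes 1 and 2
```
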
comfy/model_detection.py

Lines changed: 9 additions & 1 deletion
@@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["theta"] = 2000
             dit_config["out_channels"] = 128
             dit_config["global_modulation"] = True
-            dit_config["vec_in_dim"] = None
             dit_config["mlp_silu_act"] = True
             dit_config["qkv_bias"] = False
             dit_config["ops_bias"] = False
             dit_config["default_ref_method"] = "index"
             dit_config["ref_index_scale"] = 10.0
+            dit_config["txt_ids_dims"] = [3]
             patch_size = 1
         else:
             dit_config["image_model"] = "flux"

@@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["theta"] = 10000
             dit_config["out_channels"] = 16
             dit_config["qkv_bias"] = True
+            dit_config["txt_ids_dims"] = []
             patch_size = 2

         dit_config["in_channels"] = 16

@@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
         if vec_in_key in state_dict_keys:
             dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+        else:
+            dit_config["vec_in_dim"] = None

         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')

@@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["nerf_embedder_dtype"] = torch.float32
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+            dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+            dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+            if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+                dit_config["txt_ids_dims"] = [1, 2]
+
         return dit_config

     if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview

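The Ovis variant is identified purely from state-dict key names layered on top of the regular Flux detection: gate_proj weights in the double blocks flag the gated (yak) MLP, a txt_norm.scale tensor flags the extra text RMSNorm, and when both are present the text position ids switch to dims [1, 2]. A sketch of that check against a dummy key set (the prefix handling is simplified here):

```python
# Illustrative re-statement of the key-presence detection added above, run
# against a dummy set of state-dict keys.
def detect_ovis_flags(state_dict_keys, key_prefix=""):
    cfg = {}
    cfg["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
    cfg["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
    if cfg["yak_mlp"] and cfg["txt_norm"]:  # Ovis model
        cfg["txt_ids_dims"] = [1, 2]
    return cfg

keys = {
    "double_blocks.0.img_mlp.gate_proj.weight",
    "txt_norm.scale",
}
print(detect_ovis_flags(keys))  # {'yak_mlp': True, 'txt_norm': True, 'txt_ids_dims': [1, 2]}
```
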
comfy/sd.py

Lines changed: 11 additions & 2 deletions
@@ -53,6 +53,7 @@
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
 import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis

 import comfy.model_patcher
 import comfy.lora

@@ -956,6 +957,7 @@ class CLIPType(Enum):
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
     HUNYUAN_VIDEO_15 = 20
+    OVIS = 21


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):

@@ -987,6 +989,7 @@ class TEModel(Enum):
     MISTRAL3_24B = 14
     MISTRAL3_24B_PRUNED_FLUX2 = 15
     QWEN3_4B = 16
+    QWEN3_2B = 17


 def detect_te_model(sd):

@@ -1020,9 +1023,12 @@ def detect_te_model(sd):
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
-        if 'model.layers.0.self_attn.q_norm.weight' in sd:
-            return TEModel.QWEN3_4B
         weight = sd['model.layers.0.post_attention_layernorm.weight']
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            if weight.shape[0] == 2560:
+                return TEModel.QWEN3_4B
+            elif weight.shape[0] == 2048:
+                return TEModel.QWEN3_2B
         if weight.shape[0] == 5120:
             if "model.layers.39.post_attention_layernorm.weight" in sd:
                 return TEModel.MISTRAL3_24B

@@ -1150,6 +1156,9 @@ class EmptyClass:
         elif te_model == TEModel.QWEN3_4B:
             clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+        elif te_model == TEModel.QWEN3_2B:
+            clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
         else:
             # clip_l
             if clip_type == CLIPType.SD3:

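On the text-encoder side, Qwen3-family checkpoints are recognized by the presence of a q_norm weight and then disambiguated by the hidden size of the first post-attention layer norm: 2560 maps to QWEN3_4B (routed to the z_image encoder) and 2048 to the new QWEN3_2B (routed to comfy.text_encoders.ovis). A sketch of that decision with dummy tensors; the string return values stand in for the TEModel enum:

```python
import torch

# Hedged sketch of the size-based disambiguation added to detect_te_model above.
def detect_qwen3_variant(sd):
    if "model.layers.0.post_attention_layernorm.weight" in sd:
        weight = sd["model.layers.0.post_attention_layernorm.weight"]
        if "model.layers.0.self_attn.q_norm.weight" in sd:
            if weight.shape[0] == 2560:
                return "QWEN3_4B"
            elif weight.shape[0] == 2048:
                return "QWEN3_2B"
    return None

sd = {
    "model.layers.0.post_attention_layernorm.weight": torch.zeros(2048),
    "model.layers.0.self_attn.q_norm.weight": torch.zeros(128),
}
print(detect_qwen3_variant(sd))  # QWEN3_2B
```
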
comfy/text_encoders/llama.py

Lines changed: 31 additions & 0 deletions
@@ -100,6 +100,28 @@ class Qwen3_4BConfig:
     rope_scale = None
     final_norm: bool = True

+@dataclass
+class Ovis25_2BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2048
+    intermediate_size: int = 6144
+    num_hidden_layers: int = 28
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+
 @dataclass
 class Qwen25_7BVLI_Config:
     vocab_size: int = 152064

@@ -542,6 +564,15 @@ def __init__(self, config_dict, dtype, device, operations):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Ovis25_2B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Ovis25_2BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()

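Ovis25_2BConfig describes a Qwen3-style 2B text model: 28 layers, hidden size 2048, a SiLU-gated MLP of width 6144, grouped-query attention with 16 query heads sharing 8 key/value heads of dim 128, and q/k norms set to "gemma3". A rough, hypothetical shape check of what that head layout implies (head arithmetic only, not the actual Llama2_ attention implementation):

```python
import torch

# Grouped-query attention shapes implied by Ovis25_2BConfig:
# 16 query heads and 8 key/value heads, each of head_dim 128.
num_heads, num_kv_heads, head_dim, seq = 16, 8, 128, 7

q = torch.randn(1, seq, num_heads * head_dim)      # query projection width: 16 * 128 = 2048 = hidden_size
kv = torch.randn(1, seq, num_kv_heads * head_dim)  # key/value projection width: 8 * 128 = 1024

q = q.view(1, seq, num_heads, head_dim)
k = kv.view(1, seq, num_kv_heads, head_dim).repeat_interleave(num_heads // num_kv_heads, dim=2)
print(q.shape, k.shape)  # both torch.Size([1, 7, 16, 128]) once the KV heads are repeated
```
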