
Commit 8aea746

Implement gemma 3 as a text encoder. (#10241)
Not useful yet.
1 parent 8c19910 commit 8aea746

4 files changed: +142 -28 lines changed

comfy/model_detection.py

Lines changed: 2 additions & 2 deletions

@@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["patch_size"] = 2
         dit_config["in_channels"] = 16
         dit_config["dim"] = 2304
-        dit_config["cap_feat_dim"] = 2304
-        dit_config["n_layers"] = 26
+        dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
+        dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
         dit_config["n_heads"] = 24
         dit_config["n_kv_heads"] = 8
         dit_config["qk_norm"] = True

comfy/sd.py

Lines changed: 7 additions & 0 deletions

@@ -890,6 +890,7 @@ class TEModel(Enum):
     QWEN25_3B = 10
     QWEN25_7B = 11
     BYT5_SMALL_GLYPH = 12
+    GEMMA_3_4B = 13

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -912,6 +913,8 @@ def detect_te_model(sd):
             return TEModel.BYT5_SMALL_GLYPH
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            return TEModel.GEMMA_3_4B
         return TEModel.GEMMA_2_2B
     if 'model.layers.0.self_attn.k_proj.bias' in sd:
         weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1016,6 +1019,10 @@ class EmptyClass:
             clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+        elif te_model == TEModel.GEMMA_3_4B:
+            clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
+            clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
+            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif te_model == TEModel.LLAMA3_8:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
                 clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
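Both Gemma 2 and Gemma 3 checkpoints carry model.layers.0.post_feedforward_layernorm.weight, so the new per-head q_norm key is what tells them apart; the Gemma 3 branch then reuses the Lumina 2 text-encoder plumbing with model_type="gemma3_4b" plus the new NTokenizer. A hedged standalone sketch of just the key check (operating on a plain set of key names rather than comfy's TEModel enum):

def guess_gemma_variant(sd_keys):
    # Gemma 2 and Gemma 3 both have pre/post feed-forward layernorms;
    # only Gemma 3 adds per-head q/k RMSNorm weights inside attention.
    if "model.layers.0.post_feedforward_layernorm.weight" in sd_keys:
        if "model.layers.0.self_attn.q_norm.weight" in sd_keys:
            return "gemma3_4b"
        return "gemma2_2b"
    return None

print(guess_gemma_variant({"model.layers.0.post_feedforward_layernorm.weight",
                           "model.layers.0.self_attn.q_norm.weight"}))  # gemma3_4b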

comfy/text_encoders/llama.py

Lines changed: 111 additions & 22 deletions

@@ -3,6 +3,7 @@
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
+import logging

 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -28,6 +29,9 @@ class Llama2Config:
     mlp_activation = "silu"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None

 @dataclass
 class Qwen25_3BConfig:
@@ -46,6 +50,9 @@ class Qwen25_3BConfig:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None

 @dataclass
 class Qwen25_7BVLI_Config:
@@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = [16, 24, 24]
+    q_norm = None
+    k_norm = None
+    rope_scale = None

 @dataclass
 class Gemma2_2B_Config:
@@ -82,6 +92,32 @@ class Gemma2_2B_Config:
     mlp_activation = "gelu_pytorch_tanh"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    sliding_attention = None
+    rope_scale = None
+
+@dataclass
+class Gemma3_4B_Config:
+    vocab_size: int = 262208
+    hidden_size: int = 2560
+    intermediate_size: int = 10240
+    num_hidden_layers: int = 34
+    num_attention_heads: int = 8
+    num_key_value_heads: int = 4
+    max_position_embeddings: int = 131072
+    rms_norm_eps: float = 1e-6
+    rope_theta = [10000.0, 1000000.0]
+    transformer_type: str = "gemma3"
+    head_dim = 256
+    rms_norm_add = True
+    mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    sliding_attention = [False, False, False, False, False, 1024]
+    rope_scale = [1.0, 8.0]

 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
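A detail worth noting in Gemma3_4B_Config: head_dim is fixed at 256, so the attention inner size is 8 × 256 = 2048, which is smaller than hidden_size 2560, and the 8 query heads share 4 key/value heads (GQA factor 2). A quick sketch of the projection shapes this implies, assuming the Attention module sizes its projections as heads × head_dim (consistent with the v_proj/o_proj lines shown further down):

hidden_size = 2560
num_attention_heads = 8
num_key_value_heads = 4
head_dim = 256

inner_size = num_attention_heads * head_dim          # 2048, intentionally not hidden_size
kv_size = num_key_value_heads * head_dim             # 1024
print(f"q_proj: {hidden_size} -> {inner_size}")      # 2560 -> 2048
print(f"k_proj/v_proj: {hidden_size} -> {kv_size}")  # 2560 -> 1024
print(f"o_proj: {inner_size} -> {hidden_size}")      # 2048 -> 2560
print("GQA group size:", num_attention_heads // num_key_value_heads)  # 2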
@@ -106,25 +142,40 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
-    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
-    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
+def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
+    if not isinstance(theta, list):
+        theta = [theta]
+
+    out = []
+    for index, t in enumerate(theta):
+        theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
+        inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
+
+        if rope_scale is not None:
+            if isinstance(rope_scale, list):
+                inv_freq /= rope_scale[index]
+            else:
+                inv_freq /= rope_scale
+
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos = emb.cos()
+        sin = emb.sin()
+        if rope_dims is not None and position_ids.shape[0] > 1:
+            mrope_section = rope_dims * 2
+            cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+            sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+        else:
+            cos = cos.unsqueeze(1)
+            sin = sin.unsqueeze(1)
+        out.append((cos, sin))

-    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-    position_ids_expanded = position_ids[:, None, :].float()
-    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-    emb = torch.cat((freqs, freqs), dim=-1)
-    cos = emb.cos()
-    sin = emb.sin()
-    if rope_dims is not None and position_ids.shape[0] > 1:
-        mrope_section = rope_dims * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-    else:
-        cos = cos.unsqueeze(1)
-        sin = sin.unsqueeze(1)
+    if len(out) == 1:
+        return out[0]

-    return (cos, sin)
+    return out


 def apply_rope(xq, xk, freqs_cis):
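With rope_theta = [10000.0, 1000000.0] and rope_scale = [1.0, 8.0] from Gemma3_4B_Config, precompute_freqs_cis now returns a list of two (cos, sin) tables: index 0 built from theta 1e4 unscaled, index 1 from theta 1e6 with its inverse frequencies divided by 8. A small self-contained sketch of the same frequency math for a 1D position grid (a simplification, not the helper above):

import torch

def rope_tables(head_dim, positions, thetas, scales):
    # One (cos, sin) table per theta/scale pair, mirroring the loop above.
    out = []
    for theta, scale in zip(thetas, scales):
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (theta ** exponents) / scale
        freqs = torch.outer(positions.float(), inv_freq)   # [seq, head_dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)             # [seq, head_dim]
        out.append((emb.cos(), emb.sin()))
    return out

tables = rope_tables(256, torch.arange(8), thetas=[10000.0, 1000000.0], scales=[1.0, 8.0])
print(tables[0][0].shape)  # torch.Size([8, 256]) -- theta 1e4, unscaled
print(tables[1][0].shape)  # torch.Size([8, 256]) -- theta 1e6, frequencies divided by 8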
@@ -152,6 +203,14 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
         self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
         self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

+        self.q_norm = None
+        self.k_norm = None
+
+        if config.q_norm == "gemma3":
+            self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        if config.k_norm == "gemma3":
+            self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -168,6 +227,11 @@ def forward(
         xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
         xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

+        if self.q_norm is not None:
+            xq = self.q_norm(xq)
+        if self.k_norm is not None:
+            xk = self.k_norm(xk)
+
         xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

         xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
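The q_norm/k_norm modules are built with dim=head_dim and applied after the reshape to [batch, heads, seq, head_dim], so each head's 256 channels are normalized independently before RoPE is applied. A minimal sketch of that shape logic with a plain RMS norm (comfy's RMSNorm class additionally takes the add flag wired to rms_norm_add in the Gemma configs):

import torch

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the last dimension only (here: head_dim).
    scale = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * scale * weight

batch, heads, seq, head_dim = 1, 8, 16, 256
xq = torch.randn(batch, heads, seq, head_dim)
weight = torch.ones(head_dim)          # learned per-channel scale in the real model
print(rms_norm(xq, weight).shape)      # torch.Size([1, 8, 16, 256])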
@@ -192,7 +256,7 @@ def forward(self, x):
         return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))

 class TransformerBlock(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -226,7 +290,7 @@ def forward(
         return x

 class TransformerBlockGemma2(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -235,13 +299,28 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

+        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+            self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+        else:
+            self.sliding_attention = False
+
+        self.transformer_type = config.transformer_type
+
     def forward(
         self,
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
     ):
+        if self.transformer_type == 'gemma3':
+            if self.sliding_attention:
+                if x.shape[1] > self.sliding_attention:
+                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                freqs_cis = freqs_cis[1]
+            else:
+                freqs_cis = freqs_cis[0]
+
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
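The per-layer behaviour is resolved at construction time: sliding_attention[index % 6] is False for five consecutive layers and a 1024-token window for the sixth, and in the forward pass the windowed layers pick freqs_cis[1] while the rest use freqs_cis[0] (the window itself is still a TODO, hence the warning). A toy illustration of how the pattern and the table selection line up, assuming the same modulo indexing as the constructor above:

sliding_attention = [False, False, False, False, False, 1024]

for index in range(12):
    window = sliding_attention[index % len(sliding_attention)]
    table = 1 if window else 0  # freqs_cis[1] for windowed layers, freqs_cis[0] otherwise
    print(f"layer {index:2d}: window={window!s:5} -> rope table {table}")
# Layers 5 and 11 use table 1 (theta 1e6, scale 8); all other layers use table 0.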
@@ -276,16 +355,16 @@ def __init__(self, config, device=None, dtype=None, ops=None):
             device=device,
             dtype=dtype
         )
-        if self.config.transformer_type == "gemma2":
+        if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
             transformer = TransformerBlockGemma2
             self.normalize_in = True
         else:
             transformer = TransformerBlock
             self.normalize_in = False

         self.layers = nn.ModuleList([
-            transformer(config, device=device, dtype=dtype, ops=ops)
-            for _ in range(config.num_hidden_layers)
+            transformer(config, index=i, device=device, dtype=dtype, ops=ops)
+            for i in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@@ -305,6 +384,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                          position_ids,
                                          self.config.rope_theta,
+                                         self.config.rope_scale,
                                          self.config.rope_dims,
                                          device=x.device)

@@ -433,3 +513,12 @@ def __init__(self, config_dict, dtype, device, operations):

         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
+
+class Gemma3_4B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Gemma3_4B_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype

comfy/text_encoders/lumina2.py

Lines changed: 22 additions & 4 deletions

@@ -11,29 +11,47 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}

+class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer = tokenizer_data.get("spiece_model", None)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+
+    def state_dict(self):
+        return {"spiece_model": self.tokenizer.serialize_model()}

 class LuminaTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)

+class NTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)

 class Gemma2_2BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

+class Gemma3_4BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

 class LuminaModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
-        super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
+    def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
+        super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
+

+def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+    if model_type == "gemma2_2b":
+        model = Gemma2_2BModel
+    elif model_type == "gemma3_4b":
+        model = Gemma3_4BModel

-def te(dtype_llama=None, llama_scaled_fp8=None):
     class LuminaTEModel_(LuminaModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                 model_options = model_options.copy()
                 model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
-            super().__init__(device=device, dtype=dtype, model_options=model_options)
+            super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
     return LuminaTEModel_
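Since LuminaModel is now parameterized by name and clip_model, the single te() factory serves both encoders; sd.py passes model_type="gemma3_4b" when the Gemma 3 keys are detected and pairs it with the new NTokenizer. A hedged usage sketch (instantiation with real weights, dtype and device is left to comfy's loader):

import comfy.text_encoders.lumina2 as lumina2

# Build the encoder class for each variant; the loader instantiates it later.
Gemma2TE = lumina2.te()                        # defaults to model_type="gemma2_2b" -> Gemma2_2BModel
Gemma3TE = lumina2.te(model_type="gemma3_4b")  # -> Gemma3_4BModel, named "gemma3_4b"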
