 from dataclasses import dataclass
 from typing import Optional, Any
 import math
+import logging
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -28,6 +29,9 @@ class Llama2Config:
     mlp_activation = "silu"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Qwen25_3BConfig:
@@ -46,6 +50,9 @@ class Qwen25_3BConfig:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = [16, 24, 24]
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Gemma2_2B_Config:
@@ -82,6 +92,32 @@ class Gemma2_2B_Config:
     mlp_activation = "gelu_pytorch_tanh"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    sliding_attention = None
+    rope_scale = None
+
+@dataclass
+class Gemma3_4B_Config:
+    vocab_size: int = 262208
+    hidden_size: int = 2560
+    intermediate_size: int = 10240
+    num_hidden_layers: int = 34
+    num_attention_heads: int = 8
+    num_key_value_heads: int = 4
+    max_position_embeddings: int = 131072
+    rms_norm_eps: float = 1e-6
+    rope_theta = [10000.0, 1000000.0]
+    transformer_type: str = "gemma3"
+    head_dim = 256
+    rms_norm_add = True
+    mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    sliding_attention = [False, False, False, False, False, 1024]
+    rope_scale = [1.0, 8.0]
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
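Side note, not part of the commit: Gemma3_4B_Config mixes annotated dataclass fields (vocab_size, hidden_size, ...) with unannotated class attributes (rope_theta, head_dim, rope_scale, ...). A minimal sketch of what that implies, assuming the dataclass above is in scope; the override keys shown are illustrative:

# Only annotated attributes become dataclass fields and can be overridden via **config_dict;
# unannotated attributes keep their class-level defaults and are not __init__ parameters.
config = Gemma3_4B_Config(vocab_size=262208, num_hidden_layers=34)
assert config.rope_theta == [10000.0, 1000000.0]
assert config.rope_scale == [1.0, 8.0]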
@@ -106,25 +142,40 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
-    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
-    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
+def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
+    if not isinstance(theta, list):
+        theta = [theta]
+
+    out = []
+    for index, t in enumerate(theta):
+        theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
+        inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
+
+        if rope_scale is not None:
+            if isinstance(rope_scale, list):
+                inv_freq /= rope_scale[index]
+            else:
+                inv_freq /= rope_scale
+
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos = emb.cos()
+        sin = emb.sin()
+        if rope_dims is not None and position_ids.shape[0] > 1:
+            mrope_section = rope_dims * 2
+            cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+            sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+        else:
+            cos = cos.unsqueeze(1)
+            sin = sin.unsqueeze(1)
+        out.append((cos, sin))
 
-    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-    position_ids_expanded = position_ids[:, None, :].float()
-    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-    emb = torch.cat((freqs, freqs), dim=-1)
-    cos = emb.cos()
-    sin = emb.sin()
-    if rope_dims is not None and position_ids.shape[0] > 1:
-        mrope_section = rope_dims * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-    else:
-        cos = cos.unsqueeze(1)
-        sin = sin.unsqueeze(1)
+    if len(out) == 1:
+        return out[0]
 
-    return (cos, sin)
+    return out
 
 
 def apply_rope(xq, xk, freqs_cis):
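For reference, a small usage sketch of the reworked precompute_freqs_cis (not part of the commit): a scalar theta still yields a single (cos, sin) pair, while a list of thetas with matching rope_scale values (as in Gemma3_4B_Config) yields one pair per theta, later picked per layer by the Gemma3 transformer block:

import torch

position_ids = torch.arange(16).unsqueeze(0)  # (batch=1, seq_len=16)

# Scalar theta (Llama/Qwen-style configs): a single (cos, sin) pair is returned.
cos, sin = precompute_freqs_cis(128, position_ids, 10000.0)

# List of thetas with per-entry scaling (Gemma3-style configs): one pair per theta.
pairs = precompute_freqs_cis(256, position_ids, [10000.0, 1000000.0], rope_scale=[1.0, 8.0])
freqs_full, freqs_sliding = pairs  # index 0 / index 1, as selected in TransformerBlockGemma2.forward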
@@ -152,6 +203,14 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = Non
         self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
         self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
 
+        self.q_norm = None
+        self.k_norm = None
+
+        if config.q_norm == "gemma3":
+            self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        if config.k_norm == "gemma3":
+            self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -168,6 +227,11 @@ def forward(
         xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
         xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
 
+        if self.q_norm is not None:
+            xq = self.q_norm(xq)
+        if self.k_norm is not None:
+            xk = self.k_norm(xk)
+
         xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
 
         xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -192,7 +256,7 @@ def forward(self, x):
         return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
 
 class TransformerBlock(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -226,7 +290,7 @@ def forward(
         return x
 
 class TransformerBlockGemma2(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -235,13 +299,28 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = Non
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
 
+        if config.sliding_attention is not None:  # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+            self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+        else:
+            self.sliding_attention = False
+
+        self.transformer_type = config.transformer_type
+
     def forward(
         self,
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
     ):
+        if self.transformer_type == 'gemma3':
+            if self.sliding_attention:
+                if x.shape[1] > self.sliding_attention:
+                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                freqs_cis = freqs_cis[1]
+            else:
+                freqs_cis = freqs_cis[0]
+
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
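A quick illustration, using the Gemma3_4B_Config values above, of how the index % len(pattern) lookup in __init__ distributes the sliding_attention flags across layers, which in turn decides whether forward picks freqs_cis[0] or freqs_cis[1]:

pattern = [False, False, False, False, False, 1024]  # Gemma3_4B_Config.sliding_attention
per_layer = [pattern[i % len(pattern)] for i in range(34)]  # num_hidden_layers = 34
sliding_layers = [i for i, v in enumerate(per_layer) if v]
# sliding_layers == [5, 11, 17, 23, 29]; those layers use freqs_cis[1], all others use freqs_cis[0]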
@@ -276,16 +355,16 @@ def __init__(self, config, device=None, dtype=None, ops=None):
             device=device,
             dtype=dtype
         )
-        if self.config.transformer_type == "gemma2":
+        if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
             transformer = TransformerBlockGemma2
             self.normalize_in = True
         else:
             transformer = TransformerBlock
             self.normalize_in = False
 
         self.layers = nn.ModuleList([
-            transformer(config, device=device, dtype=dtype, ops=ops)
-            for _ in range(config.num_hidden_layers)
+            transformer(config, index=i, device=device, dtype=dtype, ops=ops)
+            for i in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@@ -305,6 +384,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                          position_ids,
                                          self.config.rope_theta,
+                                         self.config.rope_scale,
                                          self.config.rope_dims,
                                          device=x.device)
 
@@ -433,3 +513,12 @@ def __init__(self, config_dict, dtype, device, operations):
 
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
+
+class Gemma3_4B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Gemma3_4B_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype