@@ -76,6 +76,22 @@ def __init__(self, decoder_config, running_config=None, with_cross_attn=False):
             running_config=running_config,
         )
 
+        self.image_generation = getattr(running_config, "image_generation", False)
+
+        if self.image_generation:
+            # initialize MOE GEN params
+            self.input_layernorm_moe_gen = LayerNorm[decoder_config.layer_norm](
+                decoder_config.hidden_size, eps=decoder_config.norm_eps
+            )
+            if decoder_config.post_attention_layernorm:
+                self.post_attention_layernorm_moe_gen = LayerNorm[decoder_config.layer_norm](
+                    decoder_config.hidden_size, eps=decoder_config.norm_eps
+                )
+            self.mlp_moe_gen = MLP(
+                decoder_config,
+                running_config=running_config,
+            )
+
     def _mlp(self, hidden_states):
         if self.ffn_layernorm:
             hidden_states = self.pre_feedforward_layernorm(hidden_states)
@@ -110,17 +126,34 @@ def forward(self, layer_in, **kwargs):
         return_attn = kwargs.pop("return_attn", False)
         position_embeddings = kwargs.pop("position_embeddings", None)
 
+        text_indices = kwargs.pop("text_indices", None)
+        image_indices = kwargs.pop("image_indices", None)
 
-        norm_layer_in = self.input_layernorm(layer_in)
+        if self.image_generation:
+            assert text_indices is not None, "Text indices must be provided for image generation"
+            assert image_indices is not None, "Image indices must be provided for image generation"
+            norm_layer_in = torch.zeros_like(layer_in, dtype=layer_in.dtype, device=layer_in.device)
+            norm_layer_in[:, text_indices, :] = self.input_layernorm(layer_in[:, text_indices, :])
+            norm_layer_in[:, image_indices, :] = self.input_layernorm_moe_gen(layer_in[:, image_indices, :])
+        else:
+            norm_layer_in = self.input_layernorm(layer_in)
+
+        print("NORM_LAYER_IN:", norm_layer_in.shape, norm_layer_in.sum(), norm_layer_in)
+        print("NORM_LAYER_IN img:", norm_layer_in[:, -4098:, :].shape, norm_layer_in[:, -4098:, :].sum(), norm_layer_in[:, -4098:, :])
 
         self_attn, attns = self.self_attn(
             norm_layer_in,
             attn_mask=attn_mask,
             step=step,
             return_attn=return_attn,
             position_embeddings=position_embeddings,
+            text_indices=text_indices,
+            image_indices=image_indices,
         )
 
+        # print("SELF_ATTN:", self_attn.shape, self_attn.sum(), self_attn)
+        # print("SELF_ATTN img:", self_attn[:, -4098:, :].shape, self_attn[:, -4098:, :].sum(), self_attn[:, -4098:, :])
+
         if self.dropout_p > 0:
             self_attn = self.dropout(self_attn)
 
@@ -130,6 +163,8 @@ def forward(self, layer_in, **kwargs):
             layer_out = ff_in + self._mlp(ff_in)
             return layer_out, attns
 
+        text_sequence, image_sequence = None, None  # dirty patch
+
         if self.parallel_residual:
             if self.context_attn:
                 ctx_attn, attns = self.context_attn(
@@ -160,9 +195,27 @@ def forward(self, layer_in, **kwargs):
                     ctx_attn = self.dropout(ctx_attn)
             else:
                 ctx_attn = 0
-            ff_in = self.post_attention_layernorm(ctx_attn + self_attn + layer_in)
+            sequence = ctx_attn + self_attn + layer_in
+            if self.image_generation:
+                text_sequence = sequence[:, text_indices, :]
+                image_sequence = sequence[:, image_indices, :]
+                text_sequence = self.post_attention_layernorm(text_sequence)
+                image_sequence = self.post_attention_layernorm_moe_gen(image_sequence)
+                # print("POST_ATTENTION_LAYER_NORM text:", text_sequence.shape, text_sequence.sum(), text_sequence)
+                # print("POST_ATTENTION_LAYER_NORM img:", image_sequence.shape, image_sequence.sum(), image_sequence)
+            else:
+                ff_in = self.post_attention_layernorm(sequence)
             # we apply residual with un-normed
-            MLP = self.mlp(ff_in)
+            if self.image_generation:
+                MLP = torch.zeros_like(sequence, dtype=sequence.dtype, device=sequence.device)
+                MLP[:, text_indices, :] = self.mlp(text_sequence)
+                MLP[:, image_indices, :] = self.mlp_moe_gen(image_sequence)
+                # print("MLP text:", MLP[:, text_indices, :].shape, MLP[:, text_indices, :].sum(), MLP[:, text_indices, :])
+                # print("MLP img:", MLP[:, image_indices, :].shape, MLP[:, image_indices, :].sum(), MLP[:, image_indices, :])
+            else:
+                MLP = self.mlp(ff_in)
+            # print("MLP:", MLP.shape, MLP.sum(), MLP)
+            # print("MLP img:", MLP[:, -4098:, :].shape, MLP[:, -4098:, :].sum(), MLP[:, -4098:, :])
             layer_out = MLP + layer_in + self_attn + ctx_attn
 
         return layer_out, attns
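
A minimal, self-contained sketch of the index-based routing the hunks above implement (the names below are stand-ins for input_layernorm / input_layernorm_moe_gen and mlp / mlp_moe_gen, not the PR's real attributes): text positions go through the regular norm/MLP, image positions through the duplicated "_moe_gen" modules, and both results are scattered back into one tensor.

    import torch
    import torch.nn as nn

    hidden = 16
    norm_text, norm_image = nn.LayerNorm(hidden), nn.LayerNorm(hidden)          # stand-ins for the two layer norms
    mlp_text, mlp_image = nn.Linear(hidden, hidden), nn.Linear(hidden, hidden)  # stand-ins for the two MLPs

    x = torch.randn(1, 10, hidden)        # (batch, seq, hidden)
    text_indices = torch.arange(0, 6)     # first 6 positions are text tokens
    image_indices = torch.arange(6, 10)   # last 4 positions are image tokens

    out = torch.zeros_like(x)
    out[:, text_indices, :] = mlp_text(norm_text(x[:, text_indices, :]))
    out[:, image_indices, :] = mlp_image(norm_image(x[:, image_indices, :]))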
@@ -227,7 +280,13 @@ def __init__(
                 for i in range(decoder_config.layers)
             ]
         )
+        self.image_generation = getattr(running_config, "image_generation", False)
         self.layer_norm = LayerNorm[decoder_config.layer_norm](decoder_config.hidden_size, eps=decoder_config.norm_eps)
+        if self.image_generation:
+            # initialize MOE GEN params
+            self.layer_norm_moe_gen = LayerNorm[decoder_config.layer_norm](
+                decoder_config.hidden_size, eps=decoder_config.norm_eps
+            )
         self._disable_cache()
 
     @classmethod
@@ -268,6 +327,8 @@ def _causal_attn_mask(self, tgt_pad_mask):
         )
         if self.sliding_window > 0:
             future_mask = future_mask.triu_(-self.sliding_window)
+        # print("future_mask", future_mask.shape, future_mask.dtype, future_mask.device)
+        # print("tgt_pad_mask", tgt_pad_mask.shape, tgt_pad_mask.dtype, tgt_pad_mask.device)
         attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0)
         return attn_mask.unsqueeze(1)  # (batch x 1 x 1 x tgt_len)
 
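
For reference, a toy illustration of the mask composition in _causal_attn_mask (assuming, as is typical, that tgt_pad_mask is True at padded positions and future_mask is True where attention is allowed):

    import torch

    tgt_len = 4
    tgt_pad_mask = torch.tensor([[[False, False, False, True]]])          # (batch, 1, tgt_len), last position is padding
    future_mask = torch.ones(tgt_len, tgt_len, dtype=torch.bool).tril_()  # causal: attend to self and past only
    attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0)                   # broadcasts to (batch, tgt_len, tgt_len)
    attn_mask = attn_mask.unsqueeze(1)                                     # add a head dim for broadcasting over heads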
@@ -314,7 +375,11 @@ def forward(self, emb, **kwargs):
         with_align = kwargs.pop("with_align", False)
         return_attn = with_align or kwargs.pop("return_attn", False)
         positions = kwargs.pop("positions", None)
+        # print("positions", positions)
         position_embeddings = self.rope.update(emb.size(1), step=step, positions=positions)
+        cos, sin = position_embeddings
+        # print("COS:", cos.shape, cos.sum(), cos)
+        # print("SIN:", sin.shape, sin.sum(), sin)
         if self.rope_local is not None:
             position_embeddings_local = self.rope_local.update(emb.size(1), step=step)
         else:
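
The (cos, sin) pair unpacked above is the rotary-embedding cache; downstream it is typically applied to queries and keys with the standard rotate-half formulation. Shown here only as a reminder of what the tuple holds; this library's attention code may use a different layout:

    import torch

    def rotate_half(x):
        # split the head dimension in two and rotate: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q, k, cos, sin):
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

    q = k = torch.randn(1, 2, 5, 8)              # (batch, heads, seq, head_dim)
    cos, sin = torch.randn(5, 8), torch.randn(5, 8)
    q_rot, k_rot = apply_rope(q, k, cos, sin)    # broadcasts over batch and heads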
@@ -341,12 +406,19 @@ def forward(self, emb, **kwargs):
 
         # we need to adapt the mask for gemma3, TODO: find another condition?
         # SEEMS OK TO MASK IMAGES FOR LLAVA TOO ?
+        # print("ATTN_MASK before update", attn_mask.shape, attn_mask)
         if decoder_in is not None and attn_mask is not None:
+            # print("DECODER_IN:", decoder_in)
             attn_mask = self._update_causal_mask(attn_mask, (decoder_in == image_token_id) | (decoder_in == 151652) | (decoder_in == 151653))
+            # print("ATTN_MASK after update", attn_mask.shape, attn_mask)
+
+
+
         if self.sliding_window > 0 and step >= self.sliding_window and attn_mask is not None:
             attn_mask = attn_mask[:, :, :, -self.sliding_window:]
 
         for i, layer in enumerate(self.transformer_layers):
+            print(f"\n=================\nLAYER {i}\n=================\n")
             emb, attn = layer(
                 emb,
                 enc_out=enc_out if enc_out is not None else emb,
@@ -357,7 +429,9 @@ def forward(self, emb, **kwargs):
                 position_embeddings=(
                     position_embeddings_local if (i + 1) % self.interleave_local else position_embeddings
                 ),
+                **kwargs,
             )
+            print("EMB:", emb.shape, emb.sum(), emb)
             if with_align:
                 attn_align = layer.get_attn_align(
                     emb,
@@ -374,7 +448,19 @@ def forward(self, emb, **kwargs):
                 if attn_align is not None:
                     attn_aligns.append(attn_align)
 
-        emb = self.layer_norm(emb)
+
+        # TODO apply MOE logic here
+        if self.image_generation:
+            emb_ = torch.zeros_like(emb, dtype=emb.dtype, device=emb.device)
+            text_indices = kwargs.get("text_indices", None)
+            image_indices = kwargs.get("image_indices", None)
+            assert text_indices is not None, "Text indices must be provided for image generation"
+            assert image_indices is not None, "Image indices must be provided for image generation"
+            emb_[:, text_indices, :] = self.layer_norm(emb[:, text_indices, :])
+            emb_[:, image_indices, :] = self.layer_norm_moe_gen(emb[:, image_indices, :])
+            emb = emb_
+        else:
+            emb = self.layer_norm(emb)
 
         # we take the first head
         top_attn = None if attn is None else attn[:, 0, :, :].contiguous()
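
text_indices and image_indices are consumed from **kwargs throughout these hunks, but the diff does not show where they are produced. A hypothetical way to build them from the decoder input ids and hand them to the forward pass (split_modalities and the IMAGE_TOKEN_ID value below are illustrative assumptions, not part of this PR):

    import torch

    IMAGE_TOKEN_ID = 151655  # placeholder; the real id comes from the model/tokenizer config

    def split_modalities(decoder_in: torch.Tensor, image_token_id: int):
        # decoder_in: (batch, seq) token ids; assumes image tokens occupy the same
        # positions in every sequence of the batch, since one index set is shared batch-wide
        is_image = decoder_in[0].eq(image_token_id)
        return (~is_image).nonzero(as_tuple=True)[0], is_image.nonzero(as_tuple=True)[0]

    decoder_in = torch.tensor([[11, 12, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 13]])
    text_indices, image_indices = split_modalities(decoder_in, IMAGE_TOKEN_ID)
    # these would then travel through kwargs into every decoder layer, e.g.:
    # emb, attns = decoder(emb, text_indices=text_indices, image_indices=image_indices)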