Commit 0d99a39

WIP super dirty but image generation WORKS (mostly)
1 parent 465db9b commit 0d99a39

File tree

12 files changed: +677 −25 lines


eole/bin/convert/HF_mappings.py

Lines changed: 24 additions & 0 deletions
@@ -40,6 +40,7 @@
     "Qwen2ForCausalLM": {  # for bagel, but we need to add some conditions to keep supporting real qwen2...
         "decoder_layer_prefix": "language_model.model.layers.",
         "decoder.layer_norm.weight": "language_model.model.norm.weight",
+        "decoder.layer_norm_moe_gen.weight": "language_model.model.norm_moe_gen.weight",
         "encoder_layer_prefix": "vit_model.vision_model.encoder.layers.",
         "encoder.patch_conv.weight": "vit_model.vision_model.embeddings.patch_embedding.weight",
         "encoder.patch_conv.bias": "vit_model.vision_model.embeddings.patch_embedding.bias",
@@ -53,10 +54,33 @@
         "adapter.w_in.bias": "connector.fc1.bias",
         "adapter.w_out.weight": "connector.fc2.weight",
         "adapter.w_out.bias": "connector.fc2.bias",
+        # additional stuff, mostly replicated as-is for now
         "vit_pos_embed.pos_embed": "vit_pos_embed.pos_embed",
+        "latent_pos_embed.pos_embed": "latent_pos_embed.pos_embed",
+        "time_embedder.mlp.0.weight": "time_embedder.mlp.0.weight",
+        "time_embedder.mlp.0.bias": "time_embedder.mlp.0.bias",
+        "time_embedder.mlp.2.weight": "time_embedder.mlp.2.weight",
+        "time_embedder.mlp.2.bias": "time_embedder.mlp.2.bias",
+        "vae2llm.weight": "vae2llm.weight",
+        "vae2llm.bias": "vae2llm.bias",
+        "llm2vae.weight": "llm2vae.weight",
+        "llm2vae.bias": "llm2vae.bias",
+        # TODO: not sure how to properly grab VAE stuff
         "decoder": {
             ".self_attn.q_norm.": ".self_attn.q_norm.",
             ".self_attn.k_norm.": ".self_attn.k_norm.",
+            # MOE GEN (simplify with loop?)
+            ".self_attn.q_norm_moe_gen.": ".self_attn.q_norm_moe_gen.",
+            ".self_attn.k_norm_moe_gen.": ".self_attn.k_norm_moe_gen.",
+            ".self_attn.linear_query_moe_gen.": ".self_attn.q_proj_moe_gen.",
+            ".self_attn.linear_keys_moe_gen.": ".self_attn.k_proj_moe_gen.",
+            ".self_attn.linear_values_moe_gen.": ".self_attn.v_proj_moe_gen.",
+            ".self_attn.final_linear_moe_gen.": ".self_attn.o_proj_moe_gen.",
+            ".mlp_moe_gen.gate_up_proj.": ".mlp_moe_gen.gate_proj.",
+            ".mlp_moe_gen.down_proj.": ".mlp_moe_gen.down_proj.",
+            ".mlp_moe_gen.up_proj.": ".mlp_moe_gen.up_proj.",
+            ".input_layernorm_moe_gen.": ".input_layernorm_moe_gen.",
+            ".post_attention_layernorm_moe_gen.": ".post_attention_layernorm_moe_gen.",
         },
         "encoder": {
             ".self_attn.linear_query.": ".self_attn.q_proj.",

eole/bin/convert/convert_HF.py

Lines changed: 9 additions & 0 deletions
@@ -67,6 +67,7 @@ class HuggingfaceFiles:
     model_path: Optional[str] = None
     special_tokens_json: Optional[str] = None
     vision_config_path: Optional[str] = None
+    ae_model_path: Optional[str] = None

     # Unified dictionary to cache loaded files
     _loaded_files: dict = field(default_factory=dict, init=False)
@@ -130,6 +131,7 @@ def download_file_from_hub(file_name, required=True):
         or get_file_fn("pytorch_model.bin", required=False) or get_file_fn("ema.safetensors", required=False),
         "special_tokens_json": get_file_fn("special_tokens_map.json", required=False),
         "vision_config_path": get_file_fn("vit_config.json", required=False),
+        "ae_model_path": get_file_fn("ae.safetensors", required=False),
     }

     return cls(**paths, model_dir=args.model_dir, token=args.token)
@@ -677,6 +679,13 @@ def build_shards(model_config, hf, args, params):
     eole_safetensor = {}

     def build_first_shard(hf, eole_safetensor):
+        # let's add AE here
+        if hf.ae_model_path is not None:
+            ae_checkpoint = hf.get_load_ckpt(*os.path.split(hf.ae_model_path))
+            ae_params = safetensors.torch.load_file(ae_checkpoint)
+            for key, value in ae_params.items():
+                eole_safetensor[f"image_autoencoder.{key}"] = value
+
        for target in KEY_MAPS[hf.arch].keys():
            if model_config["share_decoder_embeddings"] and target == "generator.weight":
                continue
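
As a side note, a minimal load-side sketch (an assumption, not part of this commit) of how the "image_autoencoder." prefix added above lets the VAE weights be split back out of a merged eole shard:

# Hypothetical counterpart, not in this commit: recover the autoencoder
# weights from a merged shard by stripping the "image_autoencoder." prefix.
import safetensors.torch

def split_autoencoder_weights(shard_path, prefix="image_autoencoder."):
    state = safetensors.torch.load_file(shard_path)
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}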

eole/config/inference.py

Lines changed: 17 additions & 0 deletions
@@ -111,6 +111,23 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
         description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
     )

+    # image generation specific stuff, might move elsewhere
+    image_generation: bool | None = Field(
+        default=False,
+        description="Generate image from text input. "
+        "This will only work if the model is trained for image generation.",
+    )
+    image_width: int | None = Field(
+        default=1024,
+        description="Width of the generated image. "
+        "This will only work if the model is trained for image generation.",
+    )
+    image_height: int | None = Field(
+        default=1024,
+        description="Height of the generated image. "
+        "This will only work if the model is trained for image generation.",
+    )
+
     def get_model_path(self):
         return self.model_path[0]
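
A hedged usage sketch of the new fields: only the three image-generation options come from this diff; the import path and the checkpoint path below are assumptions for illustration.

# Sketch only: build an InferenceConfig with the new image-generation fields.
# "./bagel-eole" is a hypothetical converted checkpoint; other options keep
# their defaults.
from eole.config.inference import InferenceConfig

config = InferenceConfig(
    model_path=["./bagel-eole"],
    image_generation=True,   # enable text-to-image decoding
    image_width=1024,
    image_height=1024,
)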

eole/decoders/transformer.py

Lines changed: 90 additions & 4 deletions
@@ -76,6 +76,22 @@ def __init__(self, decoder_config, running_config=None, with_cross_attn=False):
             running_config=running_config,
         )

+        self.image_generation = getattr(running_config, "image_generation", False)
+
+        if self.image_generation:
+            # initialize MOE GEN params
+            self.input_layernorm_moe_gen = LayerNorm[decoder_config.layer_norm](
+                decoder_config.hidden_size, eps=decoder_config.norm_eps
+            )
+            if decoder_config.post_attention_layernorm:
+                self.post_attention_layernorm_moe_gen = LayerNorm[decoder_config.layer_norm](
+                    decoder_config.hidden_size, eps=decoder_config.norm_eps
+                )
+            self.mlp_moe_gen = MLP(
+                decoder_config,
+                running_config=running_config,
+            )
+
     def _mlp(self, hidden_states):
         if self.ffn_layernorm:
             hidden_states = self.pre_feedforward_layernorm(hidden_states)
@@ -110,17 +126,34 @@ def forward(self, layer_in, **kwargs):
         return_attn = kwargs.pop("return_attn", False)
         position_embeddings = kwargs.pop("position_embeddings", None)

+        text_indices = kwargs.pop("text_indices", None)
+        image_indices = kwargs.pop("image_indices", None)

-        norm_layer_in = self.input_layernorm(layer_in)
+        if self.image_generation:
+            assert text_indices is not None, "Text indices must be provided for image generation"
+            assert image_indices is not None, "Image indices must be provided for image generation"
+            norm_layer_in = torch.zeros_like(layer_in, dtype=layer_in.dtype, device=layer_in.device)
+            norm_layer_in[:, text_indices, :] = self.input_layernorm(layer_in[:, text_indices, :])
+            norm_layer_in[:, image_indices, :] = self.input_layernorm_moe_gen(layer_in[:, image_indices, :])
+        else:
+            norm_layer_in = self.input_layernorm(layer_in)
+
+        print("NORM_LAYER_IN:", norm_layer_in.shape, norm_layer_in.sum(), norm_layer_in)
+        print("NORM_LAYER_IN img:", norm_layer_in[:, -4098:, :].shape, norm_layer_in[:, -4098:, :].sum(), norm_layer_in[:, -4098:, :])

         self_attn, attns = self.self_attn(
             norm_layer_in,
             attn_mask=attn_mask,
             step=step,
             return_attn=return_attn,
             position_embeddings=position_embeddings,
+            text_indices=text_indices,
+            image_indices=image_indices,
         )

+        # print("SELF_ATTN:", self_attn.shape, self_attn.sum(), self_attn)
+        # print("SELF_ATTN img:", self_attn[:, -4098:, :].shape, self_attn[:, -4098:, :].sum(), self_attn[:, -4098:, :])
+
         if self.dropout_p > 0:
             self_attn = self.dropout(self_attn)

@@ -130,6 +163,8 @@ def forward(self, layer_in, **kwargs):
             layer_out = ff_in + self._mlp(ff_in)
             return layer_out, attns

+        text_sequence, image_sequence = None, None  # dirty patch
+
         if self.parallel_residual:
             if self.context_attn:
                 ctx_attn, attns = self.context_attn(
@@ -160,9 +195,27 @@ def forward(self, layer_in, **kwargs):
                 ctx_attn = self.dropout(ctx_attn)
             else:
                 ctx_attn = 0
-            ff_in = self.post_attention_layernorm(ctx_attn + self_attn + layer_in)
+            sequence = ctx_attn + self_attn + layer_in
+            if self.image_generation:
+                text_sequence = sequence[:, text_indices, :]
+                image_sequence = sequence[:, image_indices, :]
+                text_sequence = self.post_attention_layernorm(text_sequence)
+                image_sequence = self.post_attention_layernorm_moe_gen(image_sequence)
+                # print("POST_ATTENTION_LAYER_NORM text:", text_sequence.shape, text_sequence.sum(), text_sequence)
+                # print("POST_ATTENTION_LAYER_NORM img:", image_sequence.shape, image_sequence.sum(), image_sequence)
+            else:
+                ff_in = self.post_attention_layernorm(sequence)
             # we apply residual with un-normed
-            MLP = self.mlp(ff_in)
+            if self.image_generation:
+                MLP = torch.zeros_like(sequence, dtype=sequence.dtype, device=sequence.device)
+                MLP[:, text_indices, :] = self.mlp(text_sequence)
+                MLP[:, image_indices, :] = self.mlp_moe_gen(image_sequence)
+                # print("MLP text:", MLP[:, text_indices, :].shape, MLP[:, text_indices, :].sum(), MLP[:, text_indices, :])
+                # print("MLP img:", MLP[:, image_indices, :].shape, MLP[:, image_indices, :].sum(), MLP[:, image_indices, :])
+            else:
+                MLP = self.mlp(ff_in)
+                # print("MLP:", MLP.shape, MLP.sum(), MLP)
+                # print("MLP img:", MLP[:, -4098:, :].shape, MLP[:, -4098:, :].sum(), MLP[:, -4098:, :])
             layer_out = MLP + layer_in + self_attn + ctx_attn

             return layer_out, attns
@@ -227,7 +280,13 @@ def __init__(
                 for i in range(decoder_config.layers)
             ]
         )
+        self.image_generation = getattr(running_config, "image_generation", False)
         self.layer_norm = LayerNorm[decoder_config.layer_norm](decoder_config.hidden_size, eps=decoder_config.norm_eps)
+        if self.image_generation:
+            # initialize MOE GEN params
+            self.layer_norm_moe_gen = LayerNorm[decoder_config.layer_norm](
+                decoder_config.hidden_size, eps=decoder_config.norm_eps
+            )
         self._disable_cache()

     @classmethod
@@ -268,6 +327,8 @@ def _causal_attn_mask(self, tgt_pad_mask):
         )
         if self.sliding_window > 0:
             future_mask = future_mask.triu_(-self.sliding_window)
+        # print("future_mask", future_mask.shape, future_mask.dtype, future_mask.device)
+        # print("tgt_pad_mask", tgt_pad_mask.shape, tgt_pad_mask.dtype, tgt_pad_mask.device)
         attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0)
         return attn_mask.unsqueeze(1)  # (batch x 1 x 1 x tgt_len)

@@ -314,7 +375,11 @@ def forward(self, emb, **kwargs):
         with_align = kwargs.pop("with_align", False)
         return_attn = with_align or kwargs.pop("return_attn", False)
         positions = kwargs.pop("positions", None)
+        # print("positions", positions)
         position_embeddings = self.rope.update(emb.size(1), step=step, positions=positions)
+        cos, sin = position_embeddings
+        # print("COS:", cos.shape, cos.sum(), cos)
+        # print("SIN:", sin.shape, sin.sum(), sin)
         if self.rope_local is not None:
             position_embeddings_local = self.rope_local.update(emb.size(1), step=step)
         else:
@@ -341,12 +406,19 @@ def forward(self, emb, **kwargs):

         # we need to adapt the mask for gemma3, TODO: find another condition?
         # SEEMS OK TO MASK IMAGES FOR LLAVA TOO ?
+        # print("ATTN_MASK before update", attn_mask.shape, attn_mask)
         if decoder_in is not None and attn_mask is not None:
+            # print("DECODER_IN:", decoder_in)
             attn_mask = self._update_causal_mask(attn_mask, (decoder_in == image_token_id) | (decoder_in == 151652) | (decoder_in == 151653))
+            # print("ATTN_MASK after update", attn_mask.shape, attn_mask)
+
+
+
         if self.sliding_window > 0 and step >= self.sliding_window and attn_mask is not None:
             attn_mask = attn_mask[:, :, :, -self.sliding_window :]

         for i, layer in enumerate(self.transformer_layers):
+            print(f"\n=================\nLAYER {i}\n=================\n")
             emb, attn = layer(
                 emb,
                 enc_out=enc_out if enc_out is not None else emb,
@@ -357,7 +429,9 @@ def forward(self, emb, **kwargs):
                 position_embeddings=(
                     position_embeddings_local if (i + 1) % self.interleave_local else position_embeddings
                 ),
+                **kwargs,
             )
+            print("EMB:", emb.shape, emb.sum(), emb)
             if with_align:
                 attn_align = layer.get_attn_align(
                     emb,
@@ -374,7 +448,19 @@ def forward(self, emb, **kwargs):
             if attn_align is not None:
                 attn_aligns.append(attn_align)

-        emb = self.layer_norm(emb)
+
+        # TODO apply MOE logic here
+        if self.image_generation:
+            emb_ = torch.zeros_like(emb, dtype=emb.dtype, device=emb.device)
+            text_indices = kwargs.get("text_indices", None)
+            image_indices = kwargs.get("image_indices", None)
+            assert text_indices is not None, "Text indices must be provided for image generation"
+            assert image_indices is not None, "Image indices must be provided for image generation"
+            emb_[:, text_indices, :] = self.layer_norm(emb[:, text_indices, :])
+            emb_[:, image_indices, :] = self.layer_norm_moe_gen(emb[:, image_indices, :])
+            emb = emb_
+        else:
+            emb = self.layer_norm(emb)

         # we take the first head
         top_attn = None if attn is None else attn[:, 0, :, :].contiguous()
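
To summarize the pattern that recurs throughout this file, here is a standalone sketch of the text/image routing: positions flagged as text go through the original module, positions flagged as image latents go through the *_moe_gen twin, and both results are scattered back into one tensor. Module names and shapes are illustrative only, not eole's actual classes.

# Illustrative sketch of the modality routing used above.
import torch
import torch.nn as nn

def route_by_modality(hidden, text_indices, image_indices, text_module, image_module):
    # hidden: (batch, seq, hidden); indices are 1-D position tensors
    out = torch.zeros_like(hidden)
    out[:, text_indices, :] = text_module(hidden[:, text_indices, :])
    out[:, image_indices, :] = image_module(hidden[:, image_indices, :])
    return out

# toy usage: first 4 positions are text, last 4 are image latents
hidden = torch.randn(1, 8, 16)
out = route_by_modality(
    hidden,
    text_indices=torch.arange(0, 4),
    image_indices=torch.arange(4, 8),
    text_module=nn.LayerNorm(16),
    image_module=nn.LayerNorm(16),
)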

eole/encoders/transformer.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def __init__(
         self.mlp = MLP(
             encoder_config,
             running_config=running_config,
+            is_decoder=False,
         )

     def forward(self, layer_in, pad_mask, position_embeddings=None):
