Commit 465db9b: WIP bagel vision understanding patches
1 parent 11ab596

14 files changed, +813 -59 lines changed


eole/bin/convert/HF_mappings.py
Lines changed: 38 additions & 1 deletion

@@ -37,10 +37,45 @@
 MODEL_OVERRIDES = {
     "LlamaForCausalLM": {},  # default
     "MistralForCausalLM": {},
-    "Qwen2ForCausalLM": {
+    "Qwen2ForCausalLM": {  # for bagel, but we need to add some conditions to keep supporting real qwen2...
+        "decoder_layer_prefix": "language_model.model.layers.",
+        "decoder.layer_norm.weight": "language_model.model.norm.weight",
+        "encoder_layer_prefix": "vit_model.vision_model.encoder.layers.",
+        "encoder.patch_conv.weight": "vit_model.vision_model.embeddings.patch_embedding.weight",
+        "encoder.patch_conv.bias": "vit_model.vision_model.embeddings.patch_embedding.bias",
+        "encoder.position_embeddings.weight": "vit_model.vision_model.embeddings.position_embedding.weight",
+        "encoder.post_layernorm.weight": "vit_model.vision_model.post_layernorm.weight",
+        "encoder.post_layernorm.bias": "vit_model.vision_model.post_layernorm.bias",
+        "tgt_emb.embeddings.weight": "language_model.model.embed_tokens.weight",
+        "generator.weight": "language_model.lm_head.weight",
+        # vision_adapter
+        "adapter.w_in.weight": "connector.fc1.weight",
+        "adapter.w_in.bias": "connector.fc1.bias",
+        "adapter.w_out.weight": "connector.fc2.weight",
+        "adapter.w_out.bias": "connector.fc2.bias",
+        "vit_pos_embed.pos_embed": "vit_pos_embed.pos_embed",
+        "decoder": {
+            ".self_attn.q_norm.": ".self_attn.q_norm.",
+            ".self_attn.k_norm.": ".self_attn.k_norm.",
+        },
+        "encoder": {
+            ".self_attn.linear_query.": ".self_attn.q_proj.",
+            ".self_attn.linear_keys.": ".self_attn.k_proj.",
+            ".self_attn.linear_values.": ".self_attn.v_proj.",
+            ".self_attn.final_linear.": ".self_attn.out_proj.",
+            ".mlp.gate_up_proj.": ".mlp.fc1.",
+            ".mlp.down_proj.": ".mlp.fc2.",
+            ".input_layernorm.": ".layer_norm1.",
+            ".post_attention_layernorm.": ".layer_norm2.",
+        },
         "config": {
             "add_qkvbias": True,
             "add_final_linear_bias": False,
+            # "ffn_layernorm": True,
+            "decoder": {
+                "query_norm": True,
+                "key_norm": True,
+            },
         }
     },
     "Qwen3ForCausalLM": {
@@ -353,6 +388,7 @@
         "Gemma2ForCausalLM": "gemma-rms",
         "M2M100ForConditionalGeneration": "standard",
         "Gemma3ForConditionalGeneration": "gemma-rms",
+        "Qwen2ForCausalLM": "rms",
     },
 )

@@ -386,6 +422,7 @@
         "Mistral3ForConditionalGeneration": VisionTransformerLMModelConfig,
         "Gemma3ForConditionalGeneration": VisionTransformerLMModelConfig,
         "M2M100ForConditionalGeneration": TransformerModelConfig,
+        "Qwen2ForCausalLM": VisionTransformerLMModelConfig,
     },
 )
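The override table above only declares name correspondences. To make its shape concrete, here is a small hypothetical sketch (not the eole converter itself) of how such a table can be consumed when remapping a HF/bagel checkpoint: flat entries rename whole tensors, while the nested "encoder"/"decoder" dicts rewrite substrings inside per-layer parameter names. The `hf_name_for` helper below is invented for illustration.

```python
# Hypothetical illustration of consuming an override table like the one above.
OVERRIDE = {
    "generator.weight": "language_model.lm_head.weight",
    "encoder_layer_prefix": "vit_model.vision_model.encoder.layers.",
    "encoder": {".self_attn.linear_query.": ".self_attn.q_proj."},
}


def hf_name_for(eole_suffix, layer_idx=None):
    """Return the HF/bagel checkpoint name for an eole parameter (sketch)."""
    target = OVERRIDE.get(eole_suffix)
    if isinstance(target, str):
        return target  # whole-tensor rename, e.g. the generator / lm_head
    # per-layer weight: prepend the HF layer prefix, then apply substring rules
    name = f'{OVERRIDE["encoder_layer_prefix"]}{layer_idx}{eole_suffix}'
    for src, tgt in OVERRIDE["encoder"].items():
        name = name.replace(src, tgt)
    return name


print(hf_name_for("generator.weight"))
# language_model.lm_head.weight
print(hf_name_for(".self_attn.linear_query.weight", layer_idx=3))
# vit_model.vision_model.encoder.layers.3.self_attn.q_proj.weight
```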

eole/bin/convert/convert_HF.py
Lines changed: 36 additions & 5 deletions

@@ -66,6 +66,7 @@ class HuggingfaceFiles:
     wmap_path: Optional[str] = None
     model_path: Optional[str] = None
     special_tokens_json: Optional[str] = None
+    vision_config_path: Optional[str] = None

     # Unified dictionary to cache loaded files
     _loaded_files: dict = field(default_factory=dict, init=False)
@@ -117,7 +118,7 @@ def download_file_from_hub(file_name, required=True):

         # Fetch required and optional files
         paths = {
-            "config_path": get_file_fn("config.json", required=True),
+            "config_path": get_file_fn("llm_config.json", required=False),  # hard patch for bagel
             "tokenizer_config_json": get_file_fn("tokenizer_config.json", required=True),
             "generation_config_json": get_file_fn("generation_config.json", required=False),
             "tokenizer_model": get_file_fn("tokenizer.model", required=False)
@@ -126,8 +127,9 @@
             "wmap_path": get_file_fn("model.safetensors.index.json", required=False)
             or get_file_fn("pytorch_model.bin.index.json", required=False),
             "model_path": get_file_fn("model.safetensors", required=False)
-            or get_file_fn("pytorch_model.bin", required=False),
+            or get_file_fn("pytorch_model.bin", required=False) or get_file_fn("ema.safetensors", required=False),
             "special_tokens_json": get_file_fn("special_tokens_map.json", required=False),
+            "vision_config_path": get_file_fn("vit_config.json", required=False),
         }

         return cls(**paths, model_dir=args.model_dir, token=args.token)
@@ -270,9 +272,13 @@ def build_config_dict(hf):
     arch = hf.arch
     print("Architecture: ", arch)

-    vision_config = config.get("vision_config", None)
-    other_config = config  # save what is not text/vision for later use
-    config = config.get("text_config", config)
+    vision_config = getattr(hf, "vision_config", None)
+    if vision_config is None:
+        vision_config = config.get("vision_config", None)
+    other_config = config  # save what is not text/vision for later use
+    config = config.get("text_config", config)
+
+    print("VISION_CONFIG:", vision_config)

     model_config = {}
     training_config = {}
@@ -289,6 +295,7 @@
         "transformer_ff_moe": config.get("moe_intermediate_size", None),
         "mlp_activation_fn": ACT_TABLE[arch],
         "layer_norm": LN_TABLE[arch],
+        # TODO: this can break encoder (e.g. bagel)
         "heads_kv": config.get("multi_query", False)
         or config.get(
             "num_key_value_heads",
@@ -351,6 +358,30 @@
     model_config["projector_activation_fn"] = other_config.get("projector_hidden_act", "gelu")
     model_config["spatial_merge_size"] = other_config.get("spatial_merge_size", None)

+    if arch == "Qwen2ForCausalLM":
+        model_config["adapter"] = "bagel"
+        model_config["encoder"] = {
+            "mlp_activation_fn": "gelu-tanh",  # no up_proj it seems
+            "hidden_size": vision_config.get("hidden_size", 1152),
+            # "image_size": vision_config["image_size"],
+            "image_size": 1024,
+            "patch_size": vision_config["patch_size"],
+            "heads": vision_config["num_attention_heads"],
+            "heads_kv": vision_config["num_attention_heads"],
+            "layers": 26,  # 27 in config, but actually 26 in safetensors...
+            "transformer_ff": vision_config["intermediate_size"],
+            # siglip style learned position embeddings (like gemma3)
+            "position_encoding_type": PositionEncodingType.Learned,
+            "n_positions": (vision_config["image_size"] // vision_config["patch_size"]) ** 2,
+            "add_ffnbias": True,
+            "add_final_linear_bias": True,
+            "add_qkvbias": True,
+            "layer_norm": "standard",
+            "patch_conv_bias": True,
+            "layernorm_pre": False,  # implies post layernorm
+            "image_token_id": 151654,
+        }
+
     if arch == "Gemma3ForConditionalGeneration":
         if model_config.get("head_dim", None) is None:
             model_config["head_dim"] = 256  # src/transformers/models/gemma3/configuration_gemma3.py#L61

eole/decoders/transformer.py
Lines changed: 6 additions & 3 deletions

@@ -110,6 +110,7 @@ def forward(self, layer_in, **kwargs):
         return_attn = kwargs.pop("return_attn", False)
         position_embeddings = kwargs.pop("position_embeddings", None)

+
         norm_layer_in = self.input_layernorm(layer_in)

         self_attn, attns = self.self_attn(
@@ -161,7 +162,8 @@ def forward(self, layer_in, **kwargs):
             ctx_attn = 0
         ff_in = self.post_attention_layernorm(ctx_attn + self_attn + layer_in)
         # we apply residual with un-normed
-        layer_out = self.mlp(ff_in) + layer_in + self_attn + ctx_attn
+        MLP = self.mlp(ff_in)
+        layer_out = MLP + layer_in + self_attn + ctx_attn

         return layer_out, attns

@@ -311,7 +313,8 @@ def forward(self, emb, **kwargs):
         step = kwargs.pop("step", None)
         with_align = kwargs.pop("with_align", False)
         return_attn = with_align or kwargs.pop("return_attn", False)
-        position_embeddings = self.rope.update(emb.size(1), step=step)
+        positions = kwargs.pop("positions", None)
+        position_embeddings = self.rope.update(emb.size(1), step=step, positions=positions)
         if self.rope_local is not None:
             position_embeddings_local = self.rope_local.update(emb.size(1), step=step)
         else:
@@ -339,7 +342,7 @@ def forward(self, emb, **kwargs):
         # we need to adapt the mask for gemma3, TODO: find another condition?
         # SEEMS OK TO MASK IMAGES FOR LLAVA TOO ?
         if decoder_in is not None and attn_mask is not None:
-            attn_mask = self._update_causal_mask(attn_mask, decoder_in == image_token_id)
+            attn_mask = self._update_causal_mask(attn_mask, (decoder_in == image_token_id) | (decoder_in == 151652) | (decoder_in == 151653))
         if self.sliding_window > 0 and step >= self.sliding_window and attn_mask is not None:
             attn_mask = attn_mask[:, :, :, -self.sliding_window :]

eole/encoders/transformer.py
Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ def __init__(
         )
         self.dropout = nn.Dropout(self.dropout_p)
         self.post_attention_layernorm = LayerNorm[encoder_config.layer_norm](
-            encoder_config.hidden_size, eps=encoder_config.norm_eps
+            encoder_config.hidden_size, eps=encoder_config.norm_eps, bias=True
         )
         self.mlp = MLP(
             encoder_config,
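Side note on the `bias=True` change: a SigLIP-style encoder checkpoint ships bias terms for its layer norms (mapped to `.layer_norm1.` / `.layer_norm2.` in HF_mappings above), so the eole-side module needs a bias parameter for those weights to land somewhere. Below is a minimal sketch with plain `torch.nn.LayerNorm` (assumes PyTorch >= 2.1, where the `bias` kwarg exists); eole's `LayerNorm[...]` registry is not reproduced here.

```python
import torch
import torch.nn as nn

hidden_size, eps = 1152, 1e-6  # illustrative SigLIP-like values

ln = nn.LayerNorm(hidden_size, eps=eps, bias=True)  # has both .weight and .bias
print(sorted(name for name, _ in ln.named_parameters()))  # ['bias', 'weight']

# Loading a checkpoint norm that ships a bias into a bias-less module would
# fail with an unexpected key, which is what the bias=True flag avoids.
state = {"weight": torch.ones(hidden_size), "bias": torch.zeros(hidden_size)}
ln.load_state_dict(state)
```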

eole/encoders/vision.py
Lines changed: 61 additions & 15 deletions

@@ -67,6 +67,15 @@ def position_ids_in_meshgrid(patch_embeds_list, max_width, flatten=True):
     return torch.stack(positions)


+
+# from bagel
+def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
+    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+    coords_h = torch.arange(0, num_patches_h)
+    coords_w = torch.arange(0, num_patches_w)
+    pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
+    return pos_ids
+
 def create_block_diagonal_mask(lengths, device):
     """
     Create a block diagonal mask based on sequence lengths.
@@ -88,6 +97,18 @@ def create_block_diagonal_mask(lengths, device):
     return mask.to(device)


+# grabbed from bagel repo
+
+def patchify(image, patch_size):
+    p = patch_size
+    c, h, w = image.shape
+    assert h % p == 0 and w % p == 0
+    image = image.reshape(c, h // p, p, w // p, p)
+    image = torch.einsum("chpwq->hwpqc", image)
+    image = image.reshape(-1, p**2 * c)
+    return image
+
+
 class VisionEncoder(nn.Module):
     def __init__(self, encoder_config, running_config=None):
         super(VisionEncoder, self).__init__()
@@ -99,12 +120,18 @@ def __init__(self, encoder_config, running_config=None):
             )
         else:
             self.rope = build_rope(encoder_config, mode="2d")
-        self.patch_conv = nn.Conv2d(
-            in_channels=encoder_config.num_channels,
-            out_channels=encoder_config.hidden_size,
-            kernel_size=encoder_config.patch_size,
-            stride=encoder_config.patch_size,
-            bias=encoder_config.patch_conv_bias,
+        # self.patch_conv = nn.Conv2d(
+        #     in_channels=encoder_config.num_channels,
+        #     out_channels=encoder_config.hidden_size,
+        #     kernel_size=encoder_config.patch_size,
+        #     stride=encoder_config.patch_size,
+        #     bias=encoder_config.patch_conv_bias,
+        # )
+        # linear patch conv for bagel
+        self.patch_conv = nn.Linear(
+            encoder_config.patch_size * encoder_config.patch_size * encoder_config.num_channels,
+            encoder_config.hidden_size,
+            bias=True,
         )
         if encoder_config.layernorm_pre:
             self.ln_pre = RMSNorm(encoder_config.hidden_size, eps=1e-5)
@@ -133,7 +160,8 @@ def from_config(cls, encoder_config, running_config=None):

     @property
     def max_patches_per_side(self):
-        return self.encoder_config.image_size // self.encoder_config.patch_size
+        return 70  # hardcoded bagel value
+        # return self.encoder_config.image_size // self.encoder_config.patch_size

     @property
     def device(self):
@@ -151,8 +179,10 @@ def forward(self, images):
         # TODO add as @property somewhere
         dtype = next(self.parameters()).dtype

+        pixel_values = [patchify(img, self.encoder_config.patch_size) for img in images]
+
         # pass images through initial convolution independently (because they may have different sizes)
-        patch_embeds_list = [self.patch_conv(img.to(dtype)) for img in images]
+        patch_embeds_list = [self.patch_conv(pv.to(dtype)) for pv in pixel_values]

         if self.ln_pre is not None:  # pixtral / mistral
             # flatten H+W then change to (H+W, C) and stack all images of ex
@@ -171,17 +201,32 @@ def forward(self, images):
             patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
             mask = None

+        patch_embeds = patch_embeds.transpose(1, 2)  # (N_img, Seqlen, D)
+
         # positional embeddings
-        positions = position_ids_in_meshgrid(
-            patch_embeds_list,
-            max_width=self.encoder_config.image_size // self.encoder_config.patch_size,
-            flatten=self.ln_pre is not None,  # dirty flag need to improve
-        ).to(self.device)
+        # positions = position_ids_in_meshgrid(
+        #     # patch_embeds_list,
+        #     images,
+        #     max_width=self.encoder_config.image_size // self.encoder_config.patch_size,
+        #     flatten=self.ln_pre is not None,  # dirty flag need to improve
+        # ).to(self.device)
+        positions = torch.cat([
+            get_flattened_position_ids_extrapolate(
+                img.shape[-2],
+                img.shape[-1],
+                self.encoder_config.patch_size,
+                self.max_patches_per_side,
+
+            )
+            for img in images
+        ], axis=0).unsqueeze(0).to(self.device)
+
         # TODO: make this cleaner
         if hasattr(self, "position_embeddings"):
             # this is only used for rope
             position_embeddings = None
-            patch_embeds += self.position_embeddings(positions)
+            pos_embeds = self.position_embeddings(positions)
+            patch_embeds += pos_embeds
         else:
             position_embeddings = self.rope.update(
                 patch_embeds.size(1),
@@ -197,7 +242,7 @@ def forward(self, images):
         if self.post_layernorm is not None:
             out = self.post_layernorm(out)

-        return out
+        return out, positions


 # Multi-Modal Projector
@@ -266,4 +311,5 @@ def from_config(cls, model_config, running_config=None):
 str2adapter = {
     "llava": VisionLanguageAdapter,
     "gemma3": Gemma3MultiModalProjector,
+    "bagel": VisionLanguageAdapter,
 }
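To make the new image path concrete, a quick shape check of the two bagel helpers added above, run on a dummy tensor. The 14-pixel patch size and the 70 patches-per-side cap are assumed, illustrative values (70 matches the hardcoded `max_patches_per_side`).

```python
import torch


# Copied from the patch above (bagel helpers), so the snippet is self-contained.
def patchify(image, patch_size):
    p = patch_size
    c, h, w = image.shape
    assert h % p == 0 and w % p == 0
    image = image.reshape(c, h // p, p, w // p, p)
    image = torch.einsum("chpwq->hwpqc", image)
    image = image.reshape(-1, p**2 * c)
    return image


def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
    coords_h = torch.arange(0, num_patches_h)
    coords_w = torch.arange(0, num_patches_w)
    pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
    return pos_ids


patch_size = 14                 # assumed SigLIP-style patch size, illustrative
img = torch.randn(3, 224, 280)  # dummy image, both sides divisible by patch_size

patches = patchify(img, patch_size)
print(patches.shape)            # torch.Size([320, 588]): rows feed the nn.Linear patch_conv

pos_ids = get_flattened_position_ids_extrapolate(img.shape[-2], img.shape[-1], patch_size, 70)
print(pos_ids.shape, pos_ids.max())  # torch.Size([320]) tensor(1069): indices into the 70x70 table
```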
