Commit d156b05: clean up and structure
1 parent 6983164
23 files changed, +1047 −966 lines

eole/bin/convert/HF_mappings.py (15 additions, 5 deletions)

@@ -37,7 +37,7 @@
 MODEL_OVERRIDES = {
     "LlamaForCausalLM": {},  # default
     "MistralForCausalLM": {},
-    "Qwen2ForCausalLM": {  # for bagel, but we need to add some conditions to keep supporting real qwen2...
+    "Bagel": {  # bagel's arch is actually Qwen2, but requires specific mapping
         "decoder_layer_prefix": "language_model.model.layers.",
         "decoder.layer_norm.weight": "language_model.model.norm.weight",
         "decoder.layer_norm_moe_gen.weight": "language_model.model.norm_moe_gen.weight",
@@ -95,12 +95,23 @@
         "config": {
             "add_qkvbias": True,
             "add_final_linear_bias": False,
-            # "ffn_layernorm": True,
+            "adapter": "bagel",
+            "vit_position_embeddings": True,
             "decoder": {
                 "query_norm": True,
                 "key_norm": True,
             },
-        }
+            "encoder": {
+                "mlp_activation_fn": "gelu-tanh",
+                "add_ffnbias": True,
+                "add_final_linear_bias": True,
+                "add_qkvbias": True,
+                "layer_norm": "standard",
+                "patch_conv_bias": True,
+                "patch_conv_linear": True,
+                "layernorm_pre": False,  # implies post layernorm
+            },
+        },
     },
     "Qwen3ForCausalLM": {
         "decoder": {
@@ -412,7 +423,6 @@
         "Gemma2ForCausalLM": "gemma-rms",
         "M2M100ForConditionalGeneration": "standard",
         "Gemma3ForConditionalGeneration": "gemma-rms",
-        "Qwen2ForCausalLM": "rms",
     },
 )
@@ -446,7 +456,7 @@
         "Mistral3ForConditionalGeneration": VisionTransformerLMModelConfig,
         "Gemma3ForConditionalGeneration": VisionTransformerLMModelConfig,
         "M2M100ForConditionalGeneration": TransformerModelConfig,
-        "Qwen2ForCausalLM": VisionTransformerLMModelConfig,
+        "Bagel": VisionTransformerLMModelConfig,
     },
 )
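For orientation, here is a rough sketch of how these two tables plug together during conversion. Only MODEL_OVERRIDES, the "Bagel" key, and VisionTransformerLMModelConfig come from the diff above; the helper names and the second table's name are illustrative assumptions, not eole's actual API.

# Illustrative sketch only: an architecture string selects per-model
# overrides and a config class. Helper and table names are hypothetical.
def collect_overrides(arch, model_overrides):
    # MODEL_OVERRIDES supplies weight-name remappings plus a nested
    # "config" dict (adapter, encoder/decoder flags) for cases like Bagel.
    return model_overrides.get(arch, model_overrides["LlamaForCausalLM"])

def select_config_class(arch, arch_table, default_cls):
    # The second table in the diff maps arch -> config class,
    # e.g. "Bagel" -> VisionTransformerLMModelConfig.
    return arch_table.get(arch, default_cls)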

eole/bin/convert/convert_HF.py (12 additions, 17 deletions)

@@ -119,7 +119,7 @@ def download_file_from_hub(file_name, required=True):

     # Fetch required and optional files
     paths = {
-        "config_path": get_file_fn("llm_config.json", required=False),  # hard patch for bagel
+        "config_path": get_file_fn("llm_config.json", required=False) or get_file_fn("config.json", required=True),
         "tokenizer_config_json": get_file_fn("tokenizer_config.json", required=True),
         "generation_config_json": get_file_fn("generation_config.json", required=False),
         "tokenizer_model": get_file_fn("tokenizer.model", required=False)
@@ -128,7 +128,8 @@ def download_file_from_hub(file_name, required=True):
         "wmap_path": get_file_fn("model.safetensors.index.json", required=False)
         or get_file_fn("pytorch_model.bin.index.json", required=False),
         "model_path": get_file_fn("model.safetensors", required=False)
-        or get_file_fn("pytorch_model.bin", required=False) or get_file_fn("ema.safetensors", required=False),
+        or get_file_fn("pytorch_model.bin", required=False)
+        or get_file_fn("ema.safetensors", required=False),
         "special_tokens_json": get_file_fn("special_tokens_map.json", required=False),
         "vision_config_path": get_file_fn("vit_config.json", required=False),
         "ae_model_path": get_file_fn("ae.safetensors", required=False),
@@ -162,6 +163,8 @@ def __getattr__(self, name):

     @property
     def arch(self):
+        if self.model_dir == "ByteDance-Seed/BAGEL-7B-MoT":
+            return "Bagel"
         return self.config["architectures"][0]

     @property
@@ -280,8 +283,6 @@ def build_config_dict(hf):
     other_config = config  # save what is not text/vision for later use
     config = config.get("text_config", config)

-    print("VISION_CONFIG:", vision_config)
-
     model_config = {}
     training_config = {}

@@ -360,28 +361,22 @@ def build_config_dict(hf):
     model_config["projector_activation_fn"] = other_config.get("projector_hidden_act", "gelu")
     model_config["spatial_merge_size"] = other_config.get("spatial_merge_size", None)

-    if arch == "Qwen2ForCausalLM":
-        model_config["adapter"] = "bagel"
+    if arch == "Bagel":
         model_config["encoder"] = {
-            "mlp_activation_fn": "gelu-tanh",  # no up_proj it seems
             "hidden_size": vision_config.get("hidden_size", 1152),
-            # "image_size": vision_config["image_size"],
-            "image_size": 1024,
+            "image_size": 1024,  # 980 for VIT (vit_config.json), 1024 for VAE
             "patch_size": vision_config["patch_size"],
             "heads": vision_config["num_attention_heads"],
             "heads_kv": vision_config["num_attention_heads"],
-            "layers": 26,  # 27 in config, but actually 26 in safetensors...
+            "layers": 26,  # 27 in config, but actually 26 in safetensors
             "transformer_ff": vision_config["intermediate_size"],
             # siglip style learned position embeddings (like gemma3)
             "position_encoding_type": PositionEncodingType.Learned,
             "n_positions": (vision_config["image_size"] // vision_config["patch_size"]) ** 2,
-            "add_ffnbias": True,
-            "add_final_linear_bias": True,
-            "add_qkvbias": True,
-            "layer_norm": "standard",
-            "patch_conv_bias": True,
-            "layernorm_pre": False,  # implies post layernorm
             "image_token_id": 151654,
+            "image_start_token_id": 151652,
+            "image_end_token_id": 151653,
+            "max_patches_per_side": 70,
         }

     if arch == "Gemma3ForConditionalGeneration":
@@ -679,7 +674,7 @@ def build_shards(model_config, hf, args, params):
     eole_safetensor = {}

     def build_first_shard(hf, eole_safetensor):
-        # let's add AE here
+        # let's add AE here (visual autoencoder for image generation)
         if hf.ae_model_path is not None:
             ae_checkpoint = hf.get_load_ckpt(*os.path.split(hf.ae_model_path))
             ae_params = safetensors.torch.load_file(ae_checkpoint)
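The `or`-chained lookups above rely on the optional-file helper returning None when a file is absent. A minimal sketch of that contract, assuming huggingface_hub and simplifying away the local-directory case the real converter also handles:

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

def download_file_from_hub(repo_id, file_name, required=True):
    # Return a local path, or None when an optional file is missing,
    # so call sites can chain fallbacks with `or`.
    try:
        return hf_hub_download(repo_id=repo_id, filename=file_name)
    except EntryNotFoundError:
        if required:
            raise
        return None

With this contract, "llm_config.json or config.json" resolves to the first file that actually exists in the checkpoint repo, which replaces the earlier hard patch for bagel.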

eole/config/inference.py (6 additions, 25 deletions)

@@ -105,31 +105,12 @@ class ImageGenerationConfig(Config):
         description="Height of the generated image. "
         "This will only work if the model is trained for image generation.",
     )
-    cfg_text_scale: float | None = Field(
-        default=1.0,
-        description="Classifier-free guidance scale for text input. "
-    )
-    cfg_image_scale: float | None = Field(
-        default=1.0,
-        description="Classifier-free guidance scale for image input. "
-    )
-    cfg_interval_min: float | None = Field(
-        default=0.0,
-        description="Minimum classifier-free guidance interval. "
-    )
-    cfg_interval_max: float | None = Field(
-        default=1.0,
-        description="Maximum classifier-free guidance interval. "
-    )
-    timestep_shift: float | None = Field(
-        default=1.0,
-        description="Shift the timestep for image generation. "
-    )
-    num_timesteps: int | None = Field(
-        default=50,
-        description="Number of timesteps for image generation. "
-    )
-
+    cfg_text_scale: float | None = Field(default=1.0, description="Classifier-free guidance scale for text input. ")
+    cfg_image_scale: float | None = Field(default=1.0, description="Classifier-free guidance scale for image input. ")
+    cfg_interval_min: float | None = Field(default=0.0, description="Minimum classifier-free guidance interval. ")
+    cfg_interval_max: float | None = Field(default=1.0, description="Maximum classifier-free guidance interval. ")
+    timestep_shift: float | None = Field(default=1.0, description="Shift the timestep for image generation. ")
+    num_timesteps: int | None = Field(default=50, description="Number of timesteps for image generation. ")


 # in legacy opts, decoding config is separated (probably to be used elsewhere)
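For context on these fields: classifier-free guidance typically blends conditional and unconditional predictions, and the interval bounds gate when along the denoising trajectory that blend is applied. A generic sketch of the usual formula (not eole's actual sampler):

import torch

def guided_prediction(cond, uncond, scale, t, interval_min=0.0, interval_max=1.0):
    # Generic classifier-free guidance step; `t` is a timestep
    # normalized to [0, 1]. Outside the interval, guidance is skipped.
    if not (interval_min <= t <= interval_max):
        return cond
    return uncond + scale * (cond - uncond)

Note that scale=1.0 (the default for both cfg_text_scale and cfg_image_scale) reduces to the conditional prediction itself, i.e. no guidance.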

eole/config/models.py (11 additions, 0 deletions)

@@ -371,10 +371,17 @@ class VisionEncoderConfig(TransformerConfig, EncoderConfig):
     num_channels: int | None = 3
     image_size: int | None = 1024
     patch_size: int | None = 16
+    max_patches_per_side: int | None = None
+    max_latent_size: int | None = 64  # bagel
+    latent_patch_size: int | None = 2
+    latent_channel: int | None = 16
     image_token_id: int | None = 10  # pixtral uses 10, gemma3 uses 262144
+    image_start_token_id: int | None = None
+    image_end_token_id: int | None = None
     mm_tokens_per_image: int | None = 256  # added for gemma3
     layernorm_pre: bool = True  # True for pixtral/mistral False for gemma3
     patch_conv_bias: bool = False  # False for pixtral/mistral True for gemma3
+    patch_conv_linear: bool = False  # False for pixtral/gemma3 True for bagel


 # use Field with default= + description would be more readable
@@ -771,6 +778,10 @@ class VisionTransformerLMModelConfig(TransformerConfig, BaseModelConfig):

     adapter: str | None = Field(default="llava", description="Adapter type to use in the model.")

+    vit_position_embeddings: bool = Field(
+        default=False, description="Additional position embeddings for images, introduced for Bagel."
+    )
+
     @model_validator(mode="before")
     @classmethod
     def encoder_decoder_type(cls, data: Any) -> Any:
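A quick sanity check on how the new geometry fields relate, using the Bagel values the converter sets above. The 980-pixel VIT input and patch_size of 14 are assumptions read off the converter's comments, and the 8x VAE downsampling factor is a typical value, not confirmed by this diff:

# Bagel geometry sanity check (input values are assumptions, see above).
vit_image_size, patch_size = 980, 14
max_patches_per_side = vit_image_size // patch_size  # 70, the converter's constant
n_positions = max_patches_per_side ** 2              # 70 * 70 = 4900 learned positions

# VAE side: with max_latent_size=64, latent_patch_size=2, and an assumed
# 8x VAE downsampling factor, 64 * 2 * 8 = 1024, matching the converter's
# hard-coded image_size for the VAE path.
assert max_patches_per_side == 70 and n_positions == 4900
assert 64 * 2 * 8 == 1024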

eole/decoders/decoder.py (4 additions, 0 deletions)

@@ -14,6 +14,10 @@ def __init__(self, attentional=True):
         # Decoder state
         self.state = {}

+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @classmethod
     def from_config(cls, decoder_config, running_config=None, with_cross_attn=False):
         """Alternate constructor.
