75 changes: 73 additions & 2 deletions eole/bin/convert/HF_mappings.py
@@ -37,11 +37,81 @@
MODEL_OVERRIDES = {
"LlamaForCausalLM": {}, # default
"MistralForCausalLM": {},
"Qwen2ForCausalLM": {
"Bagel": { # bagel's arch is actually Qwen2, but requires specific mapping
"decoder_layer_prefix": "language_model.model.layers.",
"decoder.layer_norm.weight": "language_model.model.norm.weight",
"decoder.layer_norm_moe_gen.weight": "language_model.model.norm_moe_gen.weight",
"encoder_layer_prefix": "vit_model.vision_model.encoder.layers.",
"encoder.patch_conv.weight": "vit_model.vision_model.embeddings.patch_embedding.weight",
"encoder.patch_conv.bias": "vit_model.vision_model.embeddings.patch_embedding.bias",
"encoder.position_embeddings.weight": "vit_model.vision_model.embeddings.position_embedding.weight",
"encoder.post_layernorm.weight": "vit_model.vision_model.post_layernorm.weight",
"encoder.post_layernorm.bias": "vit_model.vision_model.post_layernorm.bias",
"tgt_emb.embeddings.weight": "language_model.model.embed_tokens.weight",
"generator.weight": "language_model.lm_head.weight",
# vision_adapter
"adapter.w_in.weight": "connector.fc1.weight",
"adapter.w_in.bias": "connector.fc1.bias",
"adapter.w_out.weight": "connector.fc2.weight",
"adapter.w_out.bias": "connector.fc2.bias",
# additional stuff, mostly replicated as-is for now
"vit_pos_embed.pos_embed": "vit_pos_embed.pos_embed",
"latent_pos_embed.pos_embed": "latent_pos_embed.pos_embed",
"time_embedder.mlp.0.weight": "time_embedder.mlp.0.weight",
"time_embedder.mlp.0.bias": "time_embedder.mlp.0.bias",
"time_embedder.mlp.2.weight": "time_embedder.mlp.2.weight",
"time_embedder.mlp.2.bias": "time_embedder.mlp.2.bias",
"vae2llm.weight": "vae2llm.weight",
"vae2llm.bias": "vae2llm.bias",
"llm2vae.weight": "llm2vae.weight",
"llm2vae.bias": "llm2vae.bias",
# TODO: not sure how to properly grab VAE stuff
"decoder": {
".self_attn.q_norm.": ".self_attn.q_norm.",
".self_attn.k_norm.": ".self_attn.k_norm.",
# MOE GEN (simplify with loop?)
".self_attn.q_norm_moe_gen.": ".self_attn.q_norm_moe_gen.",
".self_attn.k_norm_moe_gen.": ".self_attn.k_norm_moe_gen.",
".self_attn.linear_query_moe_gen.": ".self_attn.q_proj_moe_gen.",
".self_attn.linear_keys_moe_gen.": ".self_attn.k_proj_moe_gen.",
".self_attn.linear_values_moe_gen.": ".self_attn.v_proj_moe_gen.",
".self_attn.final_linear_moe_gen.": ".self_attn.o_proj_moe_gen.",
".mlp_moe_gen.gate_up_proj.": ".mlp_moe_gen.gate_proj.",
".mlp_moe_gen.down_proj.": ".mlp_moe_gen.down_proj.",
".mlp_moe_gen.up_proj.": ".mlp_moe_gen.up_proj.",
".input_layernorm_moe_gen.": ".input_layernorm_moe_gen.",
".post_attention_layernorm_moe_gen.": ".post_attention_layernorm_moe_gen.",
},
"encoder": {
".self_attn.linear_query.": ".self_attn.q_proj.",
".self_attn.linear_keys.": ".self_attn.k_proj.",
".self_attn.linear_values.": ".self_attn.v_proj.",
".self_attn.final_linear.": ".self_attn.out_proj.",
".mlp.gate_up_proj.": ".mlp.fc1.",
".mlp.down_proj.": ".mlp.fc2.",
".input_layernorm.": ".layer_norm1.",
".post_attention_layernorm.": ".layer_norm2.",
},
"config": {
"add_qkvbias": True,
"add_final_linear_bias": False,
"adapter": "bagel",
"vit_position_embeddings": True,
"decoder": {
"query_norm": True,
"key_norm": True,
},
"encoder": {
"mlp_activation_fn": "gelu-tanh",
"add_ffnbias": True,
"add_final_linear_bias": True,
"add_qkvbias": True,
"layer_norm": "standard",
"patch_conv_bias": True,
"patch_conv_linear": True,
"layernorm_pre": False, # implies post layernorm
},
},
},
"Qwen3ForCausalLM": {
"decoder": {
@@ -386,6 +456,7 @@
"Mistral3ForConditionalGeneration": VisionTransformerLMModelConfig,
"Gemma3ForConditionalGeneration": VisionTransformerLMModelConfig,
"M2M100ForConditionalGeneration": TransformerModelConfig,
"Bagel": VisionTransformerLMModelConfig,
},
)
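For orientation, the decoder/encoder sub-dicts above map eole-side parameter name fragments (keys) to Hugging Face checkpoint name fragments (values), with decoder_layer_prefix / encoder_layer_prefix prepended per layer. The sketch below only illustrates that naming convention; it is not the converter's actual lookup code, and hf_key_for and its signature are invented for the example.

def hf_key_for(overrides: dict, section: str, layer_idx: int, eole_suffix: str) -> str:
    # illustration only: resolve an eole layer-parameter name to its HF checkpoint key
    prefix = overrides[f"{section}_layer_prefix"]  # e.g. "language_model.model.layers."
    for eole_infix, hf_infix in overrides[section].items():
        if eole_suffix.startswith(eole_infix):
            return prefix + str(layer_idx) + hf_infix + eole_suffix[len(eole_infix):]
    return prefix + str(layer_idx) + eole_suffix  # unchanged names pass through as-is

# e.g. hf_key_for(MODEL_OVERRIDES["Bagel"], "decoder", 3, ".self_attn.linear_query_moe_gen.weight")
# -> "language_model.model.layers.3.self_attn.q_proj_moe_gen.weight"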

45 changes: 40 additions & 5 deletions eole/bin/convert/convert_HF.py
@@ -66,6 +66,8 @@ class HuggingfaceFiles:
wmap_path: Optional[str] = None
model_path: Optional[str] = None
special_tokens_json: Optional[str] = None
vision_config_path: Optional[str] = None
ae_model_path: Optional[str] = None

# Unified dictionary to cache loaded files
_loaded_files: dict = field(default_factory=dict, init=False)
@@ -117,7 +119,7 @@ def download_file_from_hub(file_name, required=True):

# Fetch required and optional files
paths = {
"config_path": get_file_fn("config.json", required=True),
"config_path": get_file_fn("llm_config.json", required=False) or get_file_fn("config.json", required=True),
"tokenizer_config_json": get_file_fn("tokenizer_config.json", required=True),
"generation_config_json": get_file_fn("generation_config.json", required=False),
"tokenizer_model": get_file_fn("tokenizer.model", required=False)
@@ -126,8 +128,11 @@ def download_file_from_hub(file_name, required=True):
"wmap_path": get_file_fn("model.safetensors.index.json", required=False)
or get_file_fn("pytorch_model.bin.index.json", required=False),
"model_path": get_file_fn("model.safetensors", required=False)
or get_file_fn("pytorch_model.bin", required=False),
or get_file_fn("pytorch_model.bin", required=False)
or get_file_fn("ema.safetensors", required=False),
"special_tokens_json": get_file_fn("special_tokens_map.json", required=False),
"vision_config_path": get_file_fn("vit_config.json", required=False),
"ae_model_path": get_file_fn("ae.safetensors", required=False),
}

return cls(**paths, model_dir=args.model_dir, token=args.token)
@@ -158,6 +163,8 @@ def __getattr__(self, name):

@property
def arch(self):
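# the published BAGEL checkpoint declares a Qwen2 architecture in its config, so key off the
# repo id to pick the Bagel-specific mapping from HF_mappings.py instead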
if self.model_dir == "ByteDance-Seed/BAGEL-7B-MoT":
return "Bagel"
return self.config["architectures"][0]

@property
@@ -270,9 +277,11 @@ def build_config_dict(hf):
arch = hf.arch
print("Architecture: ", arch)

vision_config = config.get("vision_config", None)
other_config = config # save what is not text/vision for later use
config = config.get("text_config", config)
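# Bagel keeps its ViT settings in a separate vit_config.json (see vision_config_path), so try
# the dedicated attribute first and fall back to an embedded vision_config otherwise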
vision_config = getattr(hf, "vision_config", None)
if vision_config is None:
vision_config = config.get("vision_config", None)
other_config = config # save what is not text/vision for later use
config = config.get("text_config", config)

model_config = {}
training_config = {}
Expand All @@ -289,6 +298,7 @@ def build_config_dict(hf):
"transformer_ff_moe": config.get("moe_intermediate_size", None),
"mlp_activation_fn": ACT_TABLE[arch],
"layer_norm": LN_TABLE[arch],
# TODO: this can break encoder (e.g. bagel)
"heads_kv": config.get("multi_query", False)
or config.get(
"num_key_value_heads",
@@ -351,6 +361,24 @@
model_config["projector_activation_fn"] = other_config.get("projector_hidden_act", "gelu")
model_config["spatial_merge_size"] = other_config.get("spatial_merge_size", None)

if arch == "Bagel":
model_config["encoder"] = {
"hidden_size": vision_config.get("hidden_size", 1152),
"image_size": 1024, # 980 for VIT (vit_config.json), 1024 for VAE
"patch_size": vision_config["patch_size"],
"heads": vision_config["num_attention_heads"],
"heads_kv": vision_config["num_attention_heads"],
"layers": 26, # 27 in config, but actually 26 in safetensors
"transformer_ff": vision_config["intermediate_size"],
# siglip style learned position embeddings (like gemma3)
"position_encoding_type": PositionEncodingType.Learned,
"n_positions": (vision_config["image_size"] // vision_config["patch_size"]) ** 2,
"image_token_id": 151654,
"image_start_token_id": 151652,
"image_end_token_id": 151653,
"max_patches_per_side": 70,
}

if arch == "Gemma3ForConditionalGeneration":
if model_config.get("head_dim", None) is None:
model_config["head_dim"] = 256 # src/transformers/models/gemma3/configuration_gemma3.py#L61
@@ -646,6 +674,13 @@ def build_shards(model_config, hf, args, params):
eole_safetensor = {}

def build_first_shard(hf, eole_safetensor):
# let's add AE here (visual autoencoder for image generation)
if hf.ae_model_path is not None:
ae_checkpoint = hf.get_load_ckpt(*os.path.split(hf.ae_model_path))
ae_params = safetensors.torch.load_file(ae_checkpoint)
for key, value in ae_params.items():
eole_safetensor[f"image_autoencoder.{key}"] = value

for target in KEY_MAPS[hf.arch].keys():
if model_config["share_decoder_embeddings"] and target == "generator.weight":
continue
34 changes: 33 additions & 1 deletion eole/config/inference.py
@@ -81,8 +81,40 @@ class DecodingConfig(Config):
align_debug: bool = Field(default=False, description="Print best align for each word.")


class ImageGenerationConfig(Config):
"""
Let's centralize image generation related stuff here.
This is not a complete config, but rather a subset of options
that are relevant for image generation tasks.
Used as a mixin for InferenceConfig for now, but might be properly nested at some point.
"""

# image generation specific stuff, might move elsewhere
image_generation: bool | None = Field(
default=False,
description="Generate image from text input. "
"This will only work if the model is trained for image generation.",
)
image_width: int | None = Field(
default=1024,
description="Width of the generated image. "
"This will only work if the model is trained for image generation.",
)
image_height: int | None = Field(
default=1024,
description="Height of the generated image. "
"This will only work if the model is trained for image generation.",
)
cfg_text_scale: float | None = Field(default=1.0, description="Classifier-free guidance scale for text input. ")
cfg_image_scale: float | None = Field(default=1.0, description="Classifier-free guidance scale for image input. ")
cfg_interval_min: float | None = Field(default=0.0, description="Minimum classifier-free guidance interval. ")
cfg_interval_max: float | None = Field(default=1.0, description="Maximum classifier-free guidance interval. ")
timestep_shift: float | None = Field(default=1.0, description="Shift the timestep for image generation. ")
num_timesteps: int | None = Field(default=50, description="Number of timesteps for image generation. ")


# in legacy opts, decoding config is separated (probably to be used elsewhere)
class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig):
class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig, ImageGenerationConfig):

model_config = get_config_dict()
model_config["arbitrary_types_allowed"] = True # to allow torch.dtype
11 changes: 11 additions & 0 deletions eole/config/models.py
@@ -371,10 +371,17 @@ class VisionEncoderConfig(TransformerConfig, EncoderConfig):
num_channels: int | None = 3
image_size: int | None = 1024
patch_size: int | None = 16
max_patches_per_side: int | None = None
max_latent_size: int | None = 64 # bagel
latent_patch_size: int | None = 2
latent_channel: int | None = 16
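# NOTE (assumption): with a typical 8x VAE downsample, a 1024 x 1024 image gives a 128 x 128
# latent grid; grouped into latent_patch_size = 2 patches that is 64 x 64 tokens, consistent
# with the max_latent_size = 64 default above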
image_token_id: int | None = 10 # pixtral uses 10, gemma3 uses 262144
image_start_token_id: int | None = None
image_end_token_id: int | None = None
mm_tokens_per_image: int | None = 256 # added for gemma3
layernorm_pre: bool = True # True for pixtral/mistral, False for gemma3
patch_conv_bias: bool = False # False for pixtral/mistral, True for gemma3
patch_conv_linear: bool = False # False for pixtral/gemma3, True for bagel


# use Field with default= + description would be more readable
@@ -771,6 +778,10 @@ class VisionTransformerLMModelConfig(TransformerConfig, BaseModelConfig):

adapter: str | None = Field(default="llava", description="Adapter type to use in the model.")

vit_position_embeddings: bool = Field(
default=False, description="Additional position embeddings for images, introduced for Bagel."
)

@model_validator(mode="before")
@classmethod
def encoder_decoder_type(cls, data: Any) -> Any:
4 changes: 4 additions & 0 deletions eole/decoders/decoder.py
@@ -14,6 +14,10 @@ def __init__(self, attentional=True):
# Decoder state
self.state = {}

@property
def device(self):
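# device of the first registered parameter; assumes the whole decoder sits on one device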
return next(self.parameters()).device

@classmethod
def from_config(cls, decoder_config, running_config=None, with_cross_attn=False):
"""Alternate constructor.