@@ -119,7 +119,7 @@ def download_file_from_hub(file_name, required=True):

     # Fetch required and optional files
     paths = {
-        "config_path": get_file_fn("llm_config.json", required=False),  # hard patch for bagel
+        "config_path": get_file_fn("llm_config.json", required=False) or get_file_fn("config.json", required=True),
         "tokenizer_config_json": get_file_fn("tokenizer_config.json", required=True),
         "generation_config_json": get_file_fn("generation_config.json", required=False),
         "tokenizer_model": get_file_fn("tokenizer.model", required=False)
@@ -128,7 +128,8 @@ def download_file_from_hub(file_name, required=True):
         "wmap_path": get_file_fn("model.safetensors.index.json", required=False)
         or get_file_fn("pytorch_model.bin.index.json", required=False),
         "model_path": get_file_fn("model.safetensors", required=False)
-        or get_file_fn("pytorch_model.bin", required=False) or get_file_fn("ema.safetensors", required=False),
+        or get_file_fn("pytorch_model.bin", required=False)
+        or get_file_fn("ema.safetensors", required=False),
         "special_tokens_json": get_file_fn("special_tokens_map.json", required=False),
         "vision_config_path": get_file_fn("vit_config.json", required=False),
         "ae_model_path": get_file_fn("ae.safetensors", required=False),
@@ -162,6 +163,8 @@ def __getattr__(self, name):

     @property
     def arch(self):
+        if self.model_dir == "ByteDance-Seed/BAGEL-7B-MoT":
+            return "Bagel"
         return self.config["architectures"][0]

     @property
@@ -280,8 +283,6 @@ def build_config_dict(hf):
     other_config = config  # save what is not text/vision for later use
     config = config.get("text_config", config)

-    print("VISION_CONFIG:", vision_config)
-
     model_config = {}
     training_config = {}

@@ -360,28 +361,22 @@ def build_config_dict(hf):
         model_config["projector_activation_fn"] = other_config.get("projector_hidden_act", "gelu")
         model_config["spatial_merge_size"] = other_config.get("spatial_merge_size", None)

-    if arch == "Qwen2ForCausalLM":
-        model_config["adapter"] = "bagel"
+    if arch == "Bagel":
         model_config["encoder"] = {
-            "mlp_activation_fn": "gelu-tanh",  # no up_proj it seems
             "hidden_size": vision_config.get("hidden_size", 1152),
-            # "image_size": vision_config["image_size"],
-            "image_size": 1024,
+            "image_size": 1024,  # 980 for VIT (vit_config.json), 1024 for VAE
             "patch_size": vision_config["patch_size"],
             "heads": vision_config["num_attention_heads"],
             "heads_kv": vision_config["num_attention_heads"],
-            "layers": 26,  # 27 in config, but actually 26 in safetensors...
+            "layers": 26,  # 27 in config, but actually 26 in safetensors
             "transformer_ff": vision_config["intermediate_size"],
             # siglip style learned position embeddings (like gemma3)
             "position_encoding_type": PositionEncodingType.Learned,
             "n_positions": (vision_config["image_size"] // vision_config["patch_size"]) ** 2,
-            "add_ffnbias": True,
-            "add_final_linear_bias": True,
-            "add_qkvbias": True,
-            "layer_norm": "standard",
-            "patch_conv_bias": True,
-            "layernorm_pre": False,  # implies post layernorm
             "image_token_id": 151654,
+            "image_start_token_id": 151652,
+            "image_end_token_id": 151653,
+            "max_patches_per_side": 70,
         }

     if arch == "Gemma3ForConditionalGeneration":
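
The new encoder geometry fields hang together arithmetically. A quick sanity check, assuming vit_config.json reports image_size=980 and patch_size=14 (both values are assumptions here; the diff itself only reads patch_size from the config):

```python
# Assumed SigLIP-style values from vit_config.json; not confirmed by the diff itself.
vit_image_size = 980
vit_patch_size = 14

n_positions = (vit_image_size // vit_patch_size) ** 2  # 4900 learned position embeddings
patches_per_side = vit_image_size // vit_patch_size    # 70, matching "max_patches_per_side"

print(n_positions, patches_per_side)  # 4900 70
```
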
@@ -679,7 +674,7 @@ def build_shards(model_config, hf, args, params):
     eole_safetensor = {}

     def build_first_shard(hf, eole_safetensor):
-        # let's add AE here
+        # let's add AE here (visual autoencoder for image generation)
         if hf.ae_model_path is not None:
             ae_checkpoint = hf.get_load_ckpt(*os.path.split(hf.ae_model_path))
             ae_params = safetensors.torch.load_file(ae_checkpoint)
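
For context, safetensors.torch.load_file returns a plain dict mapping parameter names to torch tensors, so the autoencoder weights can then be merged into the first shard by key. A small illustrative sketch (the file name and the "ae." prefix are assumptions, not taken from the converter):

```python
import safetensors.torch

# Illustrative only: load autoencoder weights and namespace them before merging
# into the first shard; the "ae." key prefix is an assumption.
ae_params = safetensors.torch.load_file("ae.safetensors")
first_shard = {f"ae.{name}": tensor for name, tensor in ae_params.items()}
```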