@@ -1181,87 +1181,64 @@ def _convert_aimv2(
     return out_dict


-def _convert_beit3(
-        state_dict: Dict[str, torch.Tensor],
-        model: VisionTransformer,
-) -> Dict[str, torch.Tensor]:
-    """Convert BEiT3 weights to standard VisionTransformer format."""
+def _convert_beit3(state_dict: dict, model):
+    """
+    Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict.
+    """
     import re
+    state_dict = state_dict.get("model", state_dict)  # unwrap if needed
+
+    # Prune unused text/mask token weights (vision-only conversion)
+    for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"):
+        state_dict.pop(k, None)
+
+    # Key renaming rules, applied in order (more specific patterns first)
+    rules = [
+        (r"beit3\.", ""),
+        (r"vision_embed\.cls_token", "cls_token"),
+        (r"vision_embed\.", "patch_embed."),
+        (r"embed_positions\.", "pos_embed."),
+        (r"encoder\.", ""),
+        (r"layers\.", "blocks."),
+        (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
+        (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
+        (r"final_layer_norm\.", "norm2."),
+        (r"inner_attn_ln", "norm"),
+        (r"out_proj", "proj"),
+        (r"\.A\.", "."),
+    ]

-    if 'model' in state_dict:
-        state_dict = state_dict['model']
-
-    # Remove text and mask tokens (vision-only)
-    state_dict.pop('beit3.text_embed.weight', None)
-    state_dict.pop('beit3.vision_embed.mask_token', None)
-
-    # First pass: Apply all key transformations except qkv fusion
-    intermediate_dict = {}
+    # First pass, rename keys
+    tmp = {}
     for k, v in state_dict.items():
-        # Skip B branch weights (use only A branch)
-        if '.B.' in k:
-            continue
-
-        # Apply all BEiT3 key transformations in one go
-        if 'vision_embed.cls_token' in k:
-            k = 'cls_token'
+        if ".B." in k:
+            continue  # use branch-A only
+        for old, new in rules:
+            k = re.sub(old, new, k)
+        if k == "pos_embed.weight":
+            # strip first two positions -> [1, N + 1, D] (cls token + patches)
+            tmp["pos_embed"] = v[2:].unsqueeze(0)
         else:
-            k = k.replace('beit3.', '')
-            k = k.replace('embed_positions.', 'pos_embed.')
-            k = k.replace('vision_embed.', 'patch_embed.')
-            k = k.replace('encoder.', '')
-            k = k.replace('layers.', 'blocks.')
-            k = k.replace('ffn.', 'mlp.')
-            k = k.replace('ffn_layernorm.', 'norm.')
-            k = k.replace('self_attn.', 'attn.')
-            k = k.replace('self_attn_layer_norm.', 'norm1.')
-            k = k.replace('final_layer_norm.', 'norm2.')
-            k = k.replace('inner_attn_ln', 'norm')  # Map inner attention LayerNorm to scale norm
-            k = k.replace('out_proj', 'proj')  # Map out_proj to proj
-            k = k.replace('A.', '')  # Remove A branch prefix
-
-        # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2)
-        if k == 'pos_embed.weight':
-            # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim]
-            # We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches)
-            intermediate_dict['pos_embed'] = v[2:].unsqueeze(0)  # Skip first 2 positions, add batch dim
-        else:
-            intermediate_dict[k] = v
+            tmp[k] = v
+
+    # Second pass, fuse q, k, v
+    out, buf = {}, {}
+    pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
+    for k, v in tmp.items():
+        m = pat.fullmatch(k)
+        if not m:  # anything not q/k/v -> copy through
+            out[k] = v
+            continue

-    # Second pass: Handle qkv fusion
-    out_dict = {}
-    processed_qkv = set()
-    for k, v in intermediate_dict.items():
-        # Handle attention projections - convert separate q,k,v to fused qkv
-        if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.(weight|bias)", k):
-            block_idx = re.search(r"blocks\.(\d+)", k).group(1)
-            param_type = re.search(r"\.(weight|bias)$", k).group(1)
-
-            # Only process once per block per parameter type
-            block_param_key = f"{block_idx}_{param_type}"
-            if block_param_key in processed_qkv:
-                continue
-
-            # Collect all three projections for this block
-            q_key = f"blocks.{block_idx}.attn.q_proj.{param_type}"
-            k_key = f"blocks.{block_idx}.attn.k_proj.{param_type}"
-            v_key = f"blocks.{block_idx}.attn.v_proj.{param_type}"
-
-            if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
-                qkv_tensor = torch.cat([
-                    intermediate_dict[q_key],
-                    intermediate_dict[k_key],
-                    intermediate_dict[v_key]
-                ], dim=0)
-                out_dict[f"blocks.{block_idx}.attn.qkv.{param_type}"] = qkv_tensor
-                processed_qkv.add(block_param_key)
-                continue
-            else:
-                assert False
-        else:
-            out_dict[k] = v
+        blk, which, kind = m.groups()  # block idx, 'q'/'k'/'v', 'weight'/'bias'
+        stash = buf.setdefault((blk, kind), {})  # Gather by block & param type
+        stash[which] = v
+        if len(stash) == 3:  # Have q, k, v -> concatenate
+            out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
+                [stash['q'], stash['k'], stash['v']], dim=0
+            )

-    return out_dict
+    return out


 def checkpoint_filter_fn(
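
As a quick illustration of what the new converter produces, here is a minimal, self-contained sketch run against a fabricated branch-A checkpoint. The keys, shapes, and single block below are toy values chosen only for the example; in real use the private _convert_beit3 helper is reached through checkpoint_filter_fn rather than called directly.

import torch

# Toy BEiT-3-style state dict: branch A only, one block, 4-dim embedding (all values are placeholders).
fake = {
    'beit3.vision_embed.cls_token': torch.zeros(1, 1, 4),
    'beit3.encoder.embed_positions.A.weight': torch.zeros(7, 4),            # 2 stripped slots + cls + 4 patches
    'beit3.encoder.layers.0.self_attn.q_proj.A.weight': torch.zeros(4, 4),
    'beit3.encoder.layers.0.self_attn.k_proj.A.weight': torch.zeros(4, 4),
    'beit3.encoder.layers.0.self_attn.v_proj.A.weight': torch.zeros(4, 4),
    'beit3.encoder.layers.0.self_attn.q_proj.B.weight': torch.zeros(4, 4),  # branch B, dropped
}
out = _convert_beit3(fake, model=None)        # model is not used by the converter

print(sorted(out))                            # ['blocks.0.attn.qkv.weight', 'cls_token', 'pos_embed']
print(out['pos_embed'].shape)                 # torch.Size([1, 5, 4])
print(out['blocks.0.attn.qkv.weight'].shape)  # torch.Size([12, 4])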