
Commit 2ca94a6

Compact _convert_beit3 fn

1 parent 38c5f3b

1 file changed: +52 -75 lines changed


timm/models/vision_transformer.py

@@ -1181,87 +1181,64 @@ def _convert_aimv2(
     return out_dict


-def _convert_beit3(
-        state_dict: Dict[str, torch.Tensor],
-        model: VisionTransformer,
-) -> Dict[str, torch.Tensor]:
-    """Convert BEiT3 weights to standard VisionTransformer format."""
+def _convert_beit3(state_dict: dict, model):
+    """
+    Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict.
+    """
     import re
+    state_dict = state_dict.get("model", state_dict)  # unwrap if needed
+
+    # Prune unused
+    for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"):
+        state_dict.pop(k, None)
+
+    # Key renaming rules
+    rules = [
+        (r"beit3\.", ""),
+        (r"vision_embed\.cls_token", "cls_token"),
+        (r"vision_embed\.", "patch_embed."),
+        (r"embed_positions\.", "pos_embed."),
+        (r"encoder\.", ""),
+        (r"layers\.", "blocks."),
+        (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
+        (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
+        (r"final_layer_norm\.", "norm2."),
+        (r"inner_attn_ln", "norm"),
+        (r"out_proj", "proj"),
+        (r"\.A\.", "."),
+    ]

-    if 'model' in state_dict:
-        state_dict = state_dict['model']
-
-    # Remove text and mask tokens (vision-only)
-    state_dict.pop('beit3.text_embed.weight', None)
-    state_dict.pop('beit3.vision_embed.mask_token', None)
-
-    # First pass: Apply all key transformations except qkv fusion
-    intermediate_dict = {}
+    # First pass, rename keys
+    tmp = {}
     for k, v in state_dict.items():
-        # Skip B branch weights (use only A branch)
-        if '.B.' in k:
-            continue
-
-        # Apply all BEiT3 key transformations in one go
-        if 'vision_embed.cls_token' in k:
-            k = 'cls_token'
+        if ".B." in k:
+            continue  # use branch-A only
+        for old, new in rules:
+            k = re.sub(old, new, k)
+        if k == "pos_embed.weight":
+            # strip first two positions, [1, N+1, D]
+            tmp["pos_embed"] = v[2:].unsqueeze(0)
         else:
-            k = k.replace('beit3.', '')
-            k = k.replace('embed_positions.', 'pos_embed.')
-            k = k.replace('vision_embed.', 'patch_embed.')
-            k = k.replace('encoder.', '')
-            k = k.replace('layers.', 'blocks.')
-            k = k.replace('ffn.', 'mlp.')
-            k = k.replace('ffn_layernorm.', 'norm.')
-            k = k.replace('self_attn.', 'attn.')
-            k = k.replace('self_attn_layer_norm.', 'norm1.')
-            k = k.replace('final_layer_norm.', 'norm2.')
-            k = k.replace('inner_attn_ln', 'norm')  # Map inner attention LayerNorm to scale norm
-            k = k.replace('out_proj', 'proj')  # Map out_proj to proj
-            k = k.replace('A.', '')  # Remove A branch prefix
-
-        # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2)
-        if k == 'pos_embed.weight':
-            # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim]
-            # We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches)
-            intermediate_dict['pos_embed'] = v[2:].unsqueeze(0)  # Skip first 2 positions, add batch dim
-        else:
-            intermediate_dict[k] = v
+            tmp[k] = v
+
+    # Second pass, fuse q, k, v
+    out, buf = {}, {}
+    pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
+    for k, v in tmp.items():
+        m = pat.fullmatch(k)
+        if not m:  # anything not q/k/v -> copy through
+            out[k] = v
+            continue

-    # Second pass: Handle qkv fusion
-    out_dict = {}
-    processed_qkv = set()
-    for k, v in intermediate_dict.items():
-        # Handle attention projections - convert separate q,k,v to fused qkv
-        if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.(weight|bias)", k):
-            block_idx = re.search(r"blocks\.(\d+)", k).group(1)
-            param_type = re.search(r"\.(weight|bias)$", k).group(1)
-
-            # Only process once per block per parameter type
-            block_param_key = f"{block_idx}_{param_type}"
-            if block_param_key in processed_qkv:
-                continue
-
-            # Collect all three projections for this block
-            q_key = f"blocks.{block_idx}.attn.q_proj.{param_type}"
-            k_key = f"blocks.{block_idx}.attn.k_proj.{param_type}"
-            v_key = f"blocks.{block_idx}.attn.v_proj.{param_type}"
-
-            if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
-                qkv_tensor = torch.cat([
-                    intermediate_dict[q_key],
-                    intermediate_dict[k_key],
-                    intermediate_dict[v_key]
-                ], dim=0)
-                out_dict[f"blocks.{block_idx}.attn.qkv.{param_type}"] = qkv_tensor
-                processed_qkv.add(block_param_key)
-                continue
-            else:
-                assert False
-        else:
-            out_dict[k] = v
+        blk, which, kind = m.groups()  # block idx, 'q'/'k'/'v', 'weight'/'bias'
+        stash = buf.setdefault((blk, kind), {})  # Gather by block & param type
+        stash[which] = v
+        if len(stash) == 3:  # Have q, k, v -> concatenate
+            out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
+                [stash['q'], stash['k'], stash['v']], dim=0
+            )

-    return out_dict
+    return out


 def checkpoint_filter_fn(
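For reference, a minimal standalone sketch of what the compacted first pass does to a single key and to the positional embedding. The sample key, the subset of rules shown, and the 196-patch / 768-dim sizes are illustrative assumptions, not values taken from this commit.

import re

import torch

# Subset of the renaming rules from the compacted converter above.
rules = [
    (r"beit3\.", ""),
    (r"encoder\.", ""),
    (r"layers\.", "blocks."),
    (r"self_attn\.", "attn."),
    (r"out_proj", "proj"),
    (r"\.A\.", "."),
]

k = "beit3.encoder.layers.0.self_attn.out_proj.A.weight"  # made-up key in BEiT-3 naming style
for old, new in rules:
    k = re.sub(old, new, k)
print(k)  # blocks.0.attn.proj.weight

# Positional embedding: per the removed comments, BEiT-3 stores [num_patches + 3, embed_dim];
# a standard ViT expects [1, num_patches + 1, embed_dim] (cls token + patches), so the first
# two rows are dropped and a batch dimension is added.
num_patches, embed_dim = 196, 768  # illustrative sizes
pos = torch.randn(num_patches + 3, embed_dim)
print(pos[2:].unsqueeze(0).shape)  # torch.Size([1, 197, 768])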

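A similar sketch for the second pass: per-block q/k/v projection tensors are buffered and, once all three are present, concatenated along dim 0 into a single blocks.N.attn.qkv entry. The 768-dim tensors are again assumed for illustration only.

import re

import torch

dim = 768  # assumed embedding size for illustration
tmp = {f"blocks.0.attn.{w}_proj.weight": torch.randn(dim, dim) for w in ("q", "k", "v")}

pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
out, buf = {}, {}
for k, v in tmp.items():
    m = pat.fullmatch(k)
    if not m:  # non-q/k/v entries would be copied through unchanged
        out[k] = v
        continue
    blk, which, kind = m.groups()
    stash = buf.setdefault((blk, kind), {})  # gather q, k, v per (block, weight/bias)
    stash[which] = v
    if len(stash) == 3:  # all three present -> fuse
        out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
            [stash["q"], stash["k"], stash["v"]], dim=0)

print(out["blocks.0.attn.qkv.weight"].shape)  # torch.Size([2304, 768])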