@@ -51,9 +51,10 @@ class TransformerBridge(nn.Module):

     # Top-level hook aliases for legacy TransformerLens names
     # Placing these on the main bridge ensures aliases like 'hook_embed' are available
-    hook_aliases = {
+    hook_aliases: Dict[str, Union[str, List[str]]] = {
         "hook_embed": "embed.hook_out",
-        "hook_pos_embed": "pos_embed.hook_out",
+        # rotary style models use rotary_emb.hook_out, but gpt2-style models use pos_embed.hook_out
+        "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
         "hook_unembed": "unembed.hook_out",
     }
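With the list form, the alias becomes architecture-dependent: the first target that actually exists on the loaded model wins, so hook_pos_embed resolves to pos_embed.hook_out on gpt2-style models and to rotary_emb.hook_out on rotary models. A minimal, self-contained sketch of that first-match rule, with a plain dict standing in for the bridge's real hook registry (which this diff does not show):

from typing import Dict, List, Optional, Union

HookTarget = Union[str, List[str]]

def first_existing_target(target: HookTarget, hooks: Dict[str, object]) -> Optional[str]:
    # Normalize a bare string to a one-element list, then return the
    # first candidate that is actually registered.
    candidates = [target] if isinstance(target, str) else target
    for candidate in candidates:
        if candidate in hooks:
            return candidate
    return None

# gpt2-style model: only pos_embed.hook_out exists, so the alias
# "hook_pos_embed" resolves to it; on a rotary model it would
# resolve to rotary_emb.hook_out instead.
gpt2_hooks = {"pos_embed.hook_out": object()}
assert first_existing_target(["pos_embed.hook_out", "rotary_emb.hook_out"], gpt2_hooks) == "pos_embed.hook_out"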
@@ -131,11 +132,25 @@ def _initialize_hook_registry(self) -> None:

         # Add bridge aliases if compatibility mode is enabled
         if self.compatibility_mode and self.hook_aliases:
-            for alias_name, target_name in self.hook_aliases.items():
+            for alias_name, target in self.hook_aliases.items():
                 # Use the existing alias system to resolve the target hook
-                target_hook = resolve_alias(self, alias_name, self.hook_aliases)
-                if target_hook is not None:
-                    self._hook_registry[alias_name] = target_hook
+                # Wrap each candidate in a Dict[str, str] since resolve_alias
+                # does not accept list-valued targets
+                if isinstance(target, list):
+                    # For list targets, try each one until one resolves
+                    for single_target in target:
+                        try:
+                            target_hook = resolve_alias(
+                                self, alias_name, {alias_name: single_target}
+                            )
+                            if target_hook is not None:
+                                self._hook_registry[alias_name] = target_hook
+                                break
+                        except AttributeError:
+                            continue
+                else:
+                    target_hook = resolve_alias(self, alias_name, {alias_name: target})
+                    if target_hook is not None:
+                        self._hook_registry[alias_name] = target_hook

         self._hook_registry_initialized = True
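The registry path also wraps each candidate in try/except because resolve_alias walks an attribute path such as rotary_emb.hook_out, which raises AttributeError on models without that module. Since the same first-match pattern recurs below in hook_dict, __getattr__, and the cache handling, a shared helper could centralize it; a hypothetical sketch (the helper name and its use are assumptions, not part of this diff):

from typing import Any, Dict, List, Optional, Union

# resolve_alias is the existing TransformerLens helper used in the diff;
# its import path is omitted here.

def _resolve_first_alias_target(
    bridge: Any, alias_name: str, target: Union[str, List[str]]
) -> Optional[Any]:
    """Return the first hook that resolves, trying candidates in order."""
    candidates = [target] if isinstance(target, str) else target
    for single_target in candidates:
        try:
            # resolve_alias expects a Dict[str, str], so wrap one candidate at a time
            hook = resolve_alias(bridge, alias_name, {alias_name: single_target})
        except AttributeError:
            # Candidate path (e.g. rotary_emb.hook_out) is absent on this model
            continue
        if hook is not None:
            return hook
    return None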
@@ -213,9 +228,17 @@ def hook_dict(self) -> dict[str, HookPoint]:

         # Add aliases if compatibility mode is enabled
         if self.compatibility_mode:
-            for alias_name, target_name in self.hook_aliases.items():
-                if target_name in hooks:
-                    hooks[alias_name] = hooks[target_name]
+            for alias_name, target in self.hook_aliases.items():
+                # Handle both string and list target names
+                if isinstance(target, list):
+                    # For list targets, find the first one that exists in hooks
+                    for single_target in target:
+                        if single_target in hooks:
+                            hooks[alias_name] = hooks[single_target]
+                            break
+                else:
+                    if target in hooks:
+                        hooks[alias_name] = hooks[target]

         return hooks
@@ -239,9 +262,16 @@ def __getattr__(self, name: str) -> Any:

         # Check if this is a hook alias when compatibility mode is enabled
         if self.compatibility_mode and name in self.hook_aliases:
-            target_name = self.hook_aliases[name]
-            if target_name in self._hook_registry:
-                return self._hook_registry[target_name]
+            target = self.hook_aliases[name]
+            # Handle both string and list target names
+            if isinstance(target, list):
+                # For list targets, find the first one that exists in the registry
+                for single_target in target:
+                    if single_target in self._hook_registry:
+                        return self._hook_registry[single_target]
+            else:
+                if target in self._hook_registry:
+                    return self._hook_registry[target]

         return super().__getattr__(name)
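Together, the hook_dict and __getattr__ changes keep legacy alias access working no matter which concrete hook a model provides. A self-contained sketch of the hook_dict copy step, with plain object() instances standing in for HookPoint values:

from typing import Dict, List, Union

# Concrete hooks present on a gpt2-style model (no rotary_emb).
hooks: Dict[str, object] = {"pos_embed.hook_out": object(), "embed.hook_out": object()}
hook_aliases: Dict[str, Union[str, List[str]]] = {
    "hook_embed": "embed.hook_out",
    "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
}

# Same first-match copy as in the diff's hook_dict property.
for alias_name, target in hook_aliases.items():
    if isinstance(target, list):
        for single_target in target:
            if single_target in hooks:
                hooks[alias_name] = hooks[single_target]
                break
    else:
        if target in hooks:
            hooks[alias_name] = hooks[target]

# The alias entries are the very same objects as their targets.
assert hooks["hook_pos_embed"] is hooks["pos_embed.hook_out"]
assert hooks["hook_embed"] is hooks["embed.hook_out"]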
@@ -1040,7 +1070,15 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
         # If compatibility mode is enabled, we need to handle aliases
         # Create duplicate cache entries for TransformerLens compatibility
         # Use the aliases collected from components (reverse mapping: new -> old)
-        reverse_aliases = {new_name: old_name for old_name, new_name in aliases.items()}
+        # Handle the case where some alias values might be lists
+        reverse_aliases = {}
+        for old_name, new_name in aliases.items():
+            if isinstance(new_name, list):
+                # For list values, create a mapping for each item in the list
+                for single_new_name in new_name:
+                    reverse_aliases[single_new_name] = old_name
+            else:
+                reverse_aliases[new_name] = old_name

         # Create duplicate entries in cache
         cache_items_to_add = {}
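A worked example of the reverse mapping: with a list-valued alias, every candidate maps back to the legacy name, so a cache entry recorded under either concrete hook can later be duplicated under hook_pos_embed:

# Same loop as in the diff, run on a small alias table.
aliases = {
    "hook_embed": "embed.hook_out",
    "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
}

reverse_aliases = {}
for old_name, new_name in aliases.items():
    if isinstance(new_name, list):
        for single_new_name in new_name:
            reverse_aliases[single_new_name] = old_name
    else:
        reverse_aliases[new_name] = old_name

assert reverse_aliases == {
    "embed.hook_out": "hook_embed",
    "pos_embed.hook_out": "hook_pos_embed",
    "rotary_emb.hook_out": "hook_pos_embed",
}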
@@ -1056,8 +1094,16 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:

         # Add cache entries for all aliases (both hook and cache aliases)
         for alias_name, target_name in aliases.items():
-            if target_name in cache and alias_name not in cache:
-                cache[alias_name] = cache[target_name]
+            # Handle both string and list target names
+            if isinstance(target_name, list):
+                # For list targets, find the first one that exists in cache
+                for single_target in target_name:
+                    if single_target in cache and alias_name not in cache:
+                        cache[alias_name] = cache[single_target]
+                        break
+            else:
+                if target_name in cache and alias_name not in cache:
+                    cache[alias_name] = cache[target_name]

         if return_cache_object:
             cache_obj = ActivationCache(cache, self, has_batch_dim=not remove_batch_dim)
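End to end, the cache handling means run_with_cache output stays addressable by the legacy names on both architectures. A self-contained sketch with strings standing in for the torch.Tensor values the real cache holds:

# Cache as produced on a rotary model: only rotary_emb.hook_out exists.
cache = {"rotary_emb.hook_out": "positional activations"}
aliases = {"hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"]}

# Same first-match duplication as in the diff.
for alias_name, target_name in aliases.items():
    if isinstance(target_name, list):
        for single_target in target_name:
            if single_target in cache and alias_name not in cache:
                cache[alias_name] = cache[single_target]
                break
    else:
        if target_name in cache and alias_name not in cache:
            cache[alias_name] = cache[target_name]

# Legacy TransformerLens code can keep indexing by the old name:
assert cache["hook_pos_embed"] == cache["rotary_emb.hook_out"]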