
Commit 7c3ff2e

Merge remote-tracking branch 'origin/dev-3.x' into add_support_for_gpt_oss
2 parents 975bea7 + b80775f

8 files changed: +178 -29 lines changed

tests/integration/model_bridge/test_bridge_root_module_cache_compatibility.py

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,8 @@
     },
 )
 
+bridge.enable_compatibility_mode(disable_warnings=False)
+
 act_names_in_cache = [
     # "hook_embed",
     # "hook_pos_embed",

tests/integration/model_bridge/test_cache_hook_equality.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
 
 prompt = "Hello World!"
 bridge = TransformerBridge.boot_transformers(MODEL, device="cpu")
+bridge.enable_compatibility_mode(disable_warnings=False)
 hooked_transformer = HookedTransformer.from_pretrained(MODEL, device="cpu")
 
 act_names_in_cache = [

tests/integration/model_bridge/test_hook_compatibility.py

Lines changed: 1 addition & 0 deletions

@@ -71,6 +71,7 @@ def test_required_hooks_available(self, transformer_bridge):
 def hook_exists_on_model(model, hook_path: str) -> bool:
     """Check if a hook path exists on the model by traversing attributes."""
     parts = hook_path.split(".")
+    model.enable_compatibility_mode(disable_warnings=False)
     current = model
 
     try:

tests/unit/utilities/test_aliases.py

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ def test_resolve_existing_alias(self):
         mock_target = Mock()
         mock_hook = Mock()
         mock_target.actual_hook = mock_hook
+        mock_target.disable_warnings = False  # Ensure warnings are enabled
 
         hook_aliases = {"old_hook": "actual_hook"}

transformer_lens/model_bridge/bridge.py

Lines changed: 46 additions & 20 deletions

@@ -65,6 +65,7 @@ def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: An
         self.adapter = adapter
         self.cfg = adapter.cfg
         self.tokenizer = tokenizer
+        self.compatibility_mode = False
 
         # Add device information to config from the loaded model
         if not hasattr(self.cfg, "device"):
@@ -196,6 +197,29 @@ def __str__(self) -> str:
         lines.extend(self._format_component_mapping(mapping, indent=1))
         return "\n".join(lines)
 
+    def enable_compatibility_mode(self, disable_warnings: bool = False) -> None:
+        """Enable compatibility mode for the bridge.
+
+        This sets up the bridge to work with legacy HookedTransformer components/hooks.
+        It will also disable warnings about the usage of legacy components/hooks if specified.
+
+        Args:
+            disable_warnings: Whether to disable warnings about legacy components/hooks
+        """
+        # Avoid circular import
+        from transformer_lens.utilities.bridge_components import (
+            apply_fn_to_all_components,
+        )
+
+        self.compatibility_mode = True
+
+        def set_compatibility_mode(component: Any) -> None:
+            """Set compatibility mode on a component."""
+            component.compatibility_mode = True
+            component.disable_warnings = disable_warnings
+
+        apply_fn_to_all_components(self, set_compatibility_mode)
+
     # ==================== TOKENIZATION METHODS ====================
 
     def to_tokens(
@@ -807,26 +831,28 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
         for hp, _ in hooks:
             hp.remove_hooks()
 
-        # Create duplicate cache entries for TransformerLens compatibility
-        # Use the aliases collected from components (reverse mapping: new -> old)
-        reverse_aliases = {new_name: old_name for old_name, new_name in aliases.items()}
-
-        # Create duplicate entries in cache
-        cache_items_to_add = {}
-        for cache_name, cached_value in cache.items():
-            # Check if this cache name should have an alias
-            for new_name, old_name in reverse_aliases.items():
-                if cache_name == new_name:
-                    cache_items_to_add[old_name] = cached_value
-                    break
-
-        # Add the aliased entries to the cache
-        cache.update(cache_items_to_add)
-
-        # Add cache entries for all aliases (both hook and cache aliases)
-        for alias_name, target_name in aliases.items():
-            if target_name in cache and alias_name not in cache:
-                cache[alias_name] = cache[target_name]
+        if self.compatibility_mode:
+            # If compatibility mode is enabled, we need to handle aliases
+            # Create duplicate cache entries for TransformerLens compatibility
+            # Use the aliases collected from components (reverse mapping: new -> old)
+            reverse_aliases = {new_name: old_name for old_name, new_name in aliases.items()}
+
+            # Create duplicate entries in cache
+            cache_items_to_add = {}
+            for cache_name, cached_value in cache.items():
+                # Check if this cache name should have an alias
+                for new_name, old_name in reverse_aliases.items():
+                    if cache_name == new_name:
+                        cache_items_to_add[old_name] = cached_value
+                        break
+
+            # Add the aliased entries to the cache
+            cache.update(cache_items_to_add)
+
+            # Add cache entries for all aliases (both hook and cache aliases)
+            for alias_name, target_name in aliases.items():
+                if target_name in cache and alias_name not in cache:
+                    cache[alias_name] = cache[target_name]
 
         if return_cache_object:
             cache_obj = ActivationCache(cache, self, has_batch_dim=not remove_batch_dim)
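
Taken together with the test changes above, the intended usage looks roughly like the sketch below. This is a minimal sketch, assuming `TransformerBridge` is importable from `transformer_lens.model_bridge` (the tests only show the class name), that "gpt2" is a supported model, and that the cache-handling method patched above backs `run_with_cache`:

from transformer_lens.model_bridge import TransformerBridge

# Boot a bridge around a pretrained model (model name illustrative).
bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")

# Opt in to legacy HookedTransformer component/hook names. Passing
# disable_warnings=True would also silence the deprecation FutureWarnings.
bridge.enable_compatibility_mode(disable_warnings=False)

# With compatibility mode on, cached activations are duplicated under their
# legacy alias names, so old-style cache keys keep working.
logits, cache = bridge.run_with_cache("Hello World!")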

transformer_lens/model_bridge/generalized_components/base.py

Lines changed: 7 additions & 2 deletions

@@ -27,6 +27,11 @@ class GeneralizedComponent(nn.Module):
     # Class attribute indicating whether this component represents a list item (like blocks)
     is_list_item: bool = False
 
+    # Compatibility mode that can be activated/deactivated for legacy components/hooks
+    compatibility_mode: bool = False
+    # Whether to disable warnings about deprecated hooks
+    disable_warnings: bool = False
+
     # Dictionary mapping deprecated hook names to their new equivalents
     # Subclasses can override this to define their own aliases
     hook_aliases: Dict[str, str] = {}
@@ -175,7 +180,7 @@ def _getattr_helper(self, name: str) -> Any:
 
         # Check if this is a deprecated hook alias
         resolved_hook = resolve_alias(self, name, self.hook_aliases)
-        if resolved_hook is not None:
+        if resolved_hook is not None and self.compatibility_mode:
             return resolved_hook
 
         # Avoid recursion by checking if we're looking for original_component
@@ -248,7 +253,7 @@ def __getattr__(self, name: str):
         # If we reach here, we can resolve the alias normally
         resolved_property = resolve_alias(self, name, self.property_aliases)
 
-        if resolved_property is not None:
+        if resolved_property is not None and self.compatibility_mode:
             return resolved_property
 
         # If an internal call or no alias was found, just regularly get the attribute
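
To make the alias gating concrete, here is a self-contained toy re-implementation of the pattern. This is not the real `GeneralizedComponent` (whose constructor and `resolve_alias` wiring are not shown in this diff), just an illustration of how `compatibility_mode` switches deprecated names on and off:

from typing import Dict


class ToyComponent:
    """Toy stand-in for the gating pattern above (not the real class)."""

    compatibility_mode: bool = False
    hook_aliases: Dict[str, str] = {"old_hook": "new_hook"}

    def __init__(self) -> None:
        self.new_hook = object()

    def __getattr__(self, name: str):
        # Mirrors the diff: aliases only resolve when compatibility mode is on.
        if self.compatibility_mode and name in self.hook_aliases:
            return getattr(self, self.hook_aliases[name])
        raise AttributeError(name)


comp = ToyComponent()
comp.compatibility_mode = True
assert comp.old_hook is comp.new_hook  # resolves only in compatibility mode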

transformer_lens/utilities/aliases.py

Lines changed: 11 additions & 7 deletions

@@ -5,7 +5,9 @@
 
 
 def resolve_alias(
-    target_object: Any, requested_name: str, aliases: Dict[str, str]
+    target_object: Any,
+    requested_name: str,
+    aliases: Dict[str, str],
 ) -> Optional[Any]:
     """Resolve a hook alias to the actual hook object.
 
@@ -19,12 +21,14 @@
     """
     if requested_name in aliases:
        target_name = aliases[requested_name]
-        warnings.warn(
-            f"Hook '{requested_name}' is deprecated and will be removed in a future version. "
-            f"Use '{target_name}' instead.",
-            FutureWarning,
-            stacklevel=3,  # Adjusted for utility function call
-        )
+
+        if not target_object.disable_warnings:
+            warnings.warn(
+                f"Hook '{requested_name}' is deprecated and will be removed in a future version. "
+                f"Use '{target_name}' instead.",
+                FutureWarning,
+                stacklevel=3,  # Adjusted for utility function call
+            )
 
         target_name_split = target_name.split(".")
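
The `test_aliases.py` change above exercises exactly this branch. A minimal sketch of calling `resolve_alias` directly, mirroring the `Mock`-based unit test (which implies the function returns the attribute the alias points at):

import warnings
from unittest.mock import Mock

from transformer_lens.utilities.aliases import resolve_alias

# A Mock target only needs `disable_warnings` plus the aliased attribute.
target = Mock()
target.actual_hook = Mock()
target.disable_warnings = False  # keep deprecation warnings enabled

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resolved = resolve_alias(target, "old_hook", {"old_hook": "actual_hook"})

assert resolved is target.actual_hook
assert any(issubclass(w.category, FutureWarning) for w in caught)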

transformer_lens/utilities/bridge_components.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
+"""Utilities for traversing and applying functions to every component in a TransformerBridge model."""
+
+from typing import Any, Callable
+
+import torch.nn as nn
+
+from transformer_lens.model_bridge.bridge import TransformerBridge
+from transformer_lens.model_bridge.generalized_components.base import (
+    GeneralizedComponent,
+)
+
+
+def collect_all_submodules_of_component(
+    model: TransformerBridge,
+    component: GeneralizedComponent,
+    submodules: dict,
+    block_prefix: str = "",
+) -> dict:
+    """Recursively collect all submodules of a component in a TransformerBridge model.
+    Args:
+        model: The TransformerBridge model to collect submodules from
+        component: The component to collect submodules from
+        submodules: A dictionary to populate with submodules (modified in-place)
+        block_prefix: Prefix for the block name, needed for components that are part of a block bridge
+    Returns:
+        Dictionary mapping submodule names to their respective submodules
+    """
+    for component_submodule in component.submodules.values():
+        submodules[block_prefix + component_submodule.name] = component_submodule
+
+        # If the submodule is a list item, we need to collect all submodules of the block bridge
+        if component_submodule.is_list_item:
+            submodules = collect_components_of_block_bridge(model, component_submodule, submodules)
+
+        # If the submodule has submodules of its own, we need to collect them recursively
+        if component_submodule.submodules:
+            submodules = collect_all_submodules_of_component(
+                model, component_submodule, submodules, block_prefix
+            )
+    return submodules
+
+
+def collect_components_of_block_bridge(
+    model: TransformerBridge, component: GeneralizedComponent, components: dict
+) -> dict:
+    """Collect all components of a BlockBridge component.
+    Args:
+        model: The TransformerBridge model to collect components from
+        component: The BlockBridge component to collect components from
+        components: A dictionary to populate with components (modified in-place)
+    Returns:
+        Dictionary mapping component names to their respective components
+    """
+
+    # Retrieve the remote component list from the adapter (we need a ModuleList to iterate over)
+    remote_module_list = model.adapter.get_remote_component(model.original_model, component.name)
+
+    # Make sure the remote component is a ModuleList
+    if isinstance(remote_module_list, nn.ModuleList):
+        for block in remote_module_list:
+            components[block.name] = block
+            components = collect_all_submodules_of_component(model, block, components, block.name)
+    return components
+
+
+def collect_all_components(model: TransformerBridge, components: dict) -> dict:
+    """Collect all components in a TransformerBridge inside a dictionary.
+    The keys are the component names, and the values are the components themselves.
+    Args:
+        model: The TransformerBridge model to collect components from
+        components: A dictionary to populate with components (modified in-place)
+    Returns:
+        Dictionary mapping component names to their respective components
+    """
+
+    # Iterate through all components in the component mapping
+    for component in model.adapter.get_component_mapping().values():
+        components[component.name] = component
+        components = collect_all_submodules_of_component(model, component, components)
+
+        # If the component is a list item, we also need to collect the components of each of its blocks
+        if component.is_list_item:
+            components = collect_components_of_block_bridge(model, component, components)
+    return components
+
+
+def apply_fn_to_all_components(
+    model: TransformerBridge,
+    fn: Callable[[GeneralizedComponent], Any],
+    components: dict | None = None,
+) -> dict[str, Any]:
+    """Apply a function to all components in the TransformerBridge model.
+    Args:
+        model: The TransformerBridge model to apply the function to
+        fn: The function to apply to each component
+        components: Optional dictionary of components to apply the function to; if None, all components are collected
+    Returns:
+        return_values: A dictionary mapping component names to the return values of the function
+    """
+
+    if components is None:
+        components = collect_all_components(model, {})
+
+    return_values = {}
+
+    for component in components.values():
+        return_values[component.name] = fn(component)
+
+    return return_values
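
Finally, a minimal sketch of how this utility is meant to be driven. It mirrors what `enable_compatibility_mode` does internally; `bridge` is assumed to be a booted `TransformerBridge` as in the tests above, and `tag_component` is a hypothetical callback:

from transformer_lens.utilities.bridge_components import apply_fn_to_all_components


def tag_component(component):
    """Flip the per-component warning flag and report the component's name."""
    component.disable_warnings = True
    return component.name


# Walks the adapter's component mapping, recursing into submodules and into
# each block of list-item components, then applies fn to every component.
results = apply_fn_to_all_components(bridge, tag_component)
print(sorted(results))  # dict keyed by component name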
