Commit 7c8d9bb

created new base config class (#1042)
* created new base config class
* cleaned up imports
* reorganized config
* setup transformer bridge config
* fixed docstring issue
* fixed typing issues
* ran format
* fixed docstring
* fixed import issues
* ran format
* fixed type checking again
* fixed import again
* fixed doc string
* removed import
* separated devices
* cleaned up utils a bit
* ran format
* fixed import
* fixed imports
* ran format
* fixed typing
* updated typing
* updated name
* changed to python 3.12
* cleaned up comments
* removed extra functions
1 parent 446b9d0 commit 7c8d9bb

98 files changed: +924 -500 lines changed
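
The change that recurs across the files in this commit is the consolidation of the configuration classes into a single transformer_lens.config package. A minimal sketch of the before/after import style, assuming both config classes are re-exported from that package as the individual test imports below suggest:

# Before this commit: each config class lived in its own top-level module
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# After: config classes are imported from the consolidated package,
# including the new TransformerBridgeConfig
from transformer_lens.config import HookedTransformerConfig
from transformer_lens.config import TransformerBridgeConfig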

Large commits have some content hidden by default; only a subset of the 98 changed files is shown below.

.github/workflows/checks.yml

Lines changed: 1 addition & 1 deletion

@@ -99,7 +99,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: "3.13"
+          python-version: "3.12"
           cache: "poetry"
       - name: Cache Models used with Tests
         uses: actions/cache@v3

tests/acceptance/test_multi_gpu.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 import torch
 
 from transformer_lens.HookedTransformer import HookedTransformer
-from transformer_lens.utilities.devices import get_best_available_device
+from transformer_lens.utilities.multi_gpu import get_best_available_device
 
 
 @pytest.fixture
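
This reflects the "separated devices" step from the commit message: the helper for picking the best available device now lives in a dedicated multi-GPU utilities module, so callers only update the import path. A minimal sketch, assuming the function itself is untouched by the move:

# Old path (removed in this commit):
# from transformer_lens.utilities.devices import get_best_available_device

# New path:
from transformer_lens.utilities.multi_gpu import get_best_available_device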

tests/integration/test_attention_mask.py

Lines changed: 1 addition & 1 deletion

@@ -1,8 +1,8 @@
 import torch
 
 from transformer_lens import utils
+from transformer_lens.config import HookedTransformerConfig
 from transformer_lens.HookedTransformer import HookedTransformer
-from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 
 
 def test_attention_mask():

tests/integration/test_grouped_query_attention.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 
 from transformer_lens import HookedTransformer
 from transformer_lens.components import Attention, GroupedQueryAttention
-from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.config import HookedTransformerConfig
 
 
 def test_grouped_query_attention_output_is_correct():

tests/mocks/architecture_adapter.py

Lines changed: 12 additions & 2 deletions

@@ -4,6 +4,7 @@
 import pytest
 import torch.nn as nn
 
+from transformer_lens.config import TransformerBridgeConfig
 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 from transformer_lens.model_bridge.generalized_components import (
     AttentionBridge,
@@ -19,8 +20,17 @@ class MockArchitectureAdapter(ArchitectureAdapter):
 
     def __init__(self, cfg=None):
         if cfg is None:
-            # Create a minimal config for testing
-            cfg = SimpleNamespace(d_mlp=512, intermediate_size=512, default_prepend_bos=True)
+            # Create a minimal TransformerBridgeConfig for testing
+            cfg = TransformerBridgeConfig(
+                d_model=512,
+                d_head=64,
+                n_layers=2,
+                n_ctx=1024,
+                d_vocab=1000,
+                d_mlp=2048,
+                default_prepend_bos=True,
+                architecture="GPT2LMHeadModel",  # Default test architecture
+            )
         super().__init__(cfg)
         # Use actual bridge instances instead of tuples
         # Provide minimal config to components that require it

tests/unit/components/test_attention.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 from transformers.utils import is_bitsandbytes_available
 
 from transformer_lens.components import Attention
-from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.config import HookedTransformerConfig
 from transformer_lens.utilities.attention import complex_attn_linear
 
 if is_bitsandbytes_available():

tests/unit/factories/test_activation_function_factory.py

Lines changed: 1 addition & 1 deletion

@@ -1,10 +1,10 @@
 import pytest
 import torch
 
+from transformer_lens.config import HookedTransformerConfig
 from transformer_lens.factories.activation_function_factory import (
     ActivationFunctionFactory,
 )
-from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS
 
 

tests/unit/factories/test_mlp_factory.py

Lines changed: 1 addition & 1 deletion

@@ -4,8 +4,8 @@
 from transformer_lens.components.mlps.gated_mlp import GatedMLP
 from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
 from transformer_lens.components.mlps.mlp import MLP
+from transformer_lens.config import HookedTransformerConfig
 from transformer_lens.factories.mlp_factory import MLPFactory
-from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 
 
 def test_create_mlp_basic():

tests/unit/pretrained_weight_conversions/test_neo.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 import torch
 
 from transformer_lens import HookedTransformer
-from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.config import HookedTransformerConfig
 from transformer_lens.pretrained.weight_conversions.neo import convert_neo_weights
 
 

tests/unit/test_architecture_adapter.py

Lines changed: 13 additions & 9 deletions

@@ -9,6 +9,7 @@
     mock_model_adapter,
 )
 from tests.mocks.models import MockGemma3Model
+from transformer_lens.config import TransformerBridgeConfig
 from transformer_lens.model_bridge.supported_architectures.gemma3 import (
     Gemma3ArchitectureAdapter,
 )
@@ -37,17 +38,20 @@ def test_get_remote_component_with_mock(
     assert isinstance(mlp, nn.Module)
 
 
-class DummyHFConfig:
-    def __init__(self):
-        self.num_attention_heads = 8
-        self.num_key_value_heads = 8
-        self.hidden_size = 128
-        # Add any other attributes needed by the adapter here
-
-
 @pytest.fixture
 def cfg():
-    return DummyHFConfig()
+    return TransformerBridgeConfig(
+        d_model=128,
+        d_head=16,  # 128 / 8 heads
+        n_layers=2,
+        n_ctx=1024,
+        n_heads=8,
+        d_vocab=1000,
+        d_mlp=512,
+        n_key_value_heads=8,
+        default_prepend_bos=True,
+        architecture="Gemma3ForCausalLM",  # Test architecture
+    )
 
 
 @pytest.fixture
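
For readers mapping the old fixture onto the new one: DummyHFConfig carried HuggingFace-style attribute names, while TransformerBridgeConfig uses TransformerLens-style field names. A small reading aid, not part of the commit; the correspondence is inferred from the values above:

# HF-style attributes from the removed DummyHFConfig
hidden_size = 128
num_attention_heads = 8
num_key_value_heads = 8

# Equivalent TransformerLens-style fields used by the new fixture
d_model = hidden_size                        # 128
n_heads = num_attention_heads                # 8
n_key_value_heads = num_key_value_heads      # 8
d_head = hidden_size // num_attention_heads  # 16, matching d_head=16 above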
