Skip to content

Commit 787d5c8

Browse files
authored
feat: Bridge.boot should allow using alias model names, but show a deprecation warning (#1028)
* Automatically replace aliased model names and show a deprecation warning
* Add a test for aliased model names and the deprecation warning
1 parent ac73820 commit 787d5c8

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
including model initialization, text generation, hooks, and caching.
55
"""
66

7+
import logging
8+
79
import pytest
810
import torch
911

@@ -21,6 +23,35 @@ def test_model_initialization():
2123
assert isinstance(bridge.original_model, torch.nn.Module), "Model should be a PyTorch module"
2224

2325

26+
def test_model_initialization_with_alias(caplog):
    """Test that booting with a deprecated alias works and logs a deprecation warning.

    ``"gpt2-small"`` is a legacy TransformerLens alias for the official
    transformers name ``"gpt2"``; ``boot_transformers`` should resolve the
    alias to the official name and emit a warning-level log record.

    Args:
        caplog: pytest fixture capturing log records emitted during the test.
    """
    model_name = "gpt2-small"

    # Set logging level to capture warnings emitted during boot.
    with caplog.at_level(logging.WARNING):
        bridge = TransformerBridge.boot_transformers(model_name)

    # Basic assertions: the alias must still produce a working bridge.
    assert bridge is not None, "Bridge should be initialized"
    assert bridge.tokenizer is not None, "Tokenizer should be initialized"
    assert isinstance(
        bridge.original_model, torch.nn.Module
    ), "Model should be a PyTorch module"

    # Check that a deprecation warning was logged. Use getMessage() rather
    # than record.message: the latter attribute is only set once a handler
    # formats the record, so reading it can raise AttributeError; getMessage()
    # is the documented accessor and always works.
    deprecation_found = False
    for record in caplog.records:
        message = record.getMessage()
        if "DEPRECATED" in message:
            deprecation_found = True
            # Verify the warning contains expected content
            assert "gpt2-small" in message, "Warning should mention the deprecated alias"
            assert "gpt2" in message, "Warning should mention the official name"
            break

    assert deprecation_found, "Expected deprecation warning for alias 'gpt2-small' was not logged"
53+
54+
2455
def test_text_generation():
2556
"""Test basic text generation functionality."""
2657
model_name = "gpt2" # Use a smaller model for testing

transformer_lens/model_bridge/sources/transformers.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
44
"""
55

6-
76
import copy
7+
import logging
88
import os
99

1010
import torch
@@ -16,6 +16,7 @@
1616
)
1717

1818
from transformer_lens.model_bridge.bridge import TransformerBridge
19+
from transformer_lens.supported_models import MODEL_ALIASES
1920
from transformer_lens.utils import get_tokenizer_with_bos
2021

2122

@@ -111,6 +112,16 @@ def boot(
111112
ArchitectureAdapterFactory,
112113
)
113114

115+
# MODEL_ALIASES is a dict of {official_name: [alias1, alias2, ...]}
116+
# Check if model_name that the user passed is an alias, and if so, use the official name
117+
for official_name, aliases in MODEL_ALIASES.items():
118+
if model_name in aliases:
119+
logging.warning(
120+
f"DEPRECATED: You are using a deprecated, model_name alias '{model_name}'. TransformerLens will now load the official transformers model name, '{official_name}' instead.\n Please update your code to use the official name by changing model_name from '{model_name}' to '{official_name}'.\nSince TransformerLens v3, all model names should be the official transformers model names.\nThe aliases will be removed in the next version of TransformerLens, so please do the update now."
121+
)
122+
model_name = official_name
123+
break
124+
114125
hf_config = AutoConfig.from_pretrained(model_name, output_attentions=True)
115126

116127
# Apply config variables to hf_config before selecting adapter

0 commit comments

Comments
 (0)