Skip to content

Commit 6ca73ef

Browse files
Match device selection of TransformerBridge to HookedTransformer (#1047)
Co-authored-by: Bryce Meyer <[email protected]>
1 parent 0a58e98 commit 6ca73ef

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

transformer_lens/model_bridge/sources/transformers.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from transformer_lens.config import TransformerBridgeConfig
1919
from transformer_lens.model_bridge.bridge import TransformerBridge
2020
from transformer_lens.supported_models import MODEL_ALIASES
21-
from transformer_lens.utils import get_tokenizer_with_bos
21+
from transformer_lens.utils import get_device, get_tokenizer_with_bos
2222

2323

2424
def map_default_transformer_lens_config(hf_config):
@@ -210,9 +210,12 @@ def boot(
210210

211211
adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
212212

213+
# No device specified by user, use the best available device for the current system
214+
if device is None:
215+
device = get_device()
216+
213217
# Add device information to the config
214-
if device is not None:
215-
adapter.cfg.device = str(device)
218+
adapter.cfg.device = str(device)
216219

217220
# Load the model from HuggingFace using the original config
218221
hf_model = AutoModelForCausalLM.from_pretrained(
@@ -221,9 +224,8 @@ def boot(
221224
torch_dtype=dtype,
222225
)
223226

224-
# Move model to device if specified
225-
if device is not None:
226-
hf_model = hf_model.to(device)
227+
# Move model to device
228+
hf_model = hf_model.to(device)
227229

228230
# Load the tokenizer
229231
tokenizer = tokenizer

0 commit comments

Comments
 (0)