1 file changed, +8 −6 lines: transformer_lens/model_bridge/sources

@@ -18,7 +18,7 @@
 from transformer_lens.config import TransformerBridgeConfig
 from transformer_lens.model_bridge.bridge import TransformerBridge
 from transformer_lens.supported_models import MODEL_ALIASES
-from transformer_lens.utils import get_tokenizer_with_bos
+from transformer_lens.utils import get_device, get_tokenizer_with_bos
 
 
 def map_default_transformer_lens_config(hf_config):
@@ -210,9 +210,12 @@ def boot(
 
     adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
 
+    # No device specified by user, use the best available device for the current system
+    if device is None:
+        device = get_device()
+
     # Add device information to the config
-    if device is not None:
-        adapter.cfg.device = str(device)
+    adapter.cfg.device = str(device)
 
     # Load the model from HuggingFace using the original config
     hf_model = AutoModelForCausalLM.from_pretrained(
@@ -221,9 +224,8 @@ def boot(
         torch_dtype=dtype,
     )
 
-    # Move model to device if specified
-    if device is not None:
-        hf_model = hf_model.to(device)
+    # Move model to device
+    hf_model = hf_model.to(device)
 
     # Load the tokenizer
     tokenizer = tokenizer
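
For context, the practical effect of this change: callers of boot() no longer need to pass a device; when device is None, the best available backend is picked automatically, and the HuggingFace model is always moved there, so the config's device string and the model's actual placement agree. The sketch below shows the assumed fallback order of transformer_lens.utils.get_device; its implementation is not part of this diff, so treat it as an illustration rather than the library's exact code.

import torch

def get_device() -> torch.device:
    # Assumed behaviour (not shown in this diff): prefer CUDA, then
    # Apple-Silicon MPS, then fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# The fallback added by the patch, reproduced outside boot():
device = None
if device is None:
    device = get_device()
print(device)  # e.g. "cuda", "mps", or "cpu" depending on the machine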