Commit 887cce4

Make transformer lazy import (#292)
* lazy import for transformer library.
1 parent 80a2d82 commit 887cce4

File tree

3 files changed: +99 / -30 lines

  src/agentlab/llm/chat_api.py
  src/agentlab/llm/huggingface_utils.py
  src/agentlab/llm/llm_utils.py

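The change relies on module-level __getattr__ (PEP 562): when an attribute lookup on a module fails, Python falls back to the module's __getattr__ hook, so a heavy import can be deferred until the name is first accessed. A minimal sketch of the pattern, with placeholder names (HeavyClass, heavy_package) that are not part of agentlab:

    # mymodule.py: PEP 562 lazy re-export, minimal sketch
    def __getattr__(name: str):
        if name == "HeavyClass":  # placeholder name for illustration
            from heavy_package import HeavyClass  # imported on first access only
            return HeavyClass
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

"import mymodule" stays cheap; "mymodule.HeavyClass" (or "from mymodule import HeavyClass", which falls back to the hook since Python 3.7) triggers the real import.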

src/agentlab/llm/chat_api.py

Lines changed: 20 additions & 22 deletions

@@ -8,12 +8,10 @@

 import anthropic
 import openai
-from huggingface_hub import InferenceClient
 from openai import NOT_GIVEN, OpenAI

 import agentlab.llm.tracking as tracking
 from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
-from agentlab.llm.huggingface_utils import HFBaseChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion


@@ -139,6 +137,8 @@ def make_model(self):
             self.model_url = os.environ["AGENTLAB_MODEL_URL"]
         if self.token is None:
             self.token = os.environ["AGENTLAB_MODEL_TOKEN"]
+        # Lazy import to avoid importing HF utilities on non-HF paths
+        from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel

         return HuggingFaceURLChatModel(
             model_name=self.model_name,
@@ -438,28 +438,26 @@ def __init__(
         )


-class HuggingFaceURLChatModel(HFBaseChatModel):
-    def __init__(
-        self,
-        model_name: str,
-        base_model_name: str,
-        model_url: str,
-        token: Optional[str] = None,
-        temperature: Optional[int] = 1e-1,
-        max_new_tokens: Optional[int] = 512,
-        n_retry_server: Optional[int] = 4,
-        log_probs: Optional[bool] = False,
-    ):
-        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
-        if temperature < 1e-3:
-            logging.warning("Models might behave weirdly when temperature is too low.")
-        self.temperature = temperature
+def __getattr__(name: str):
+    """Lazy re-export of optional classes to keep imports light.
+
+    This lets users import HuggingFaceURLChatModel from agentlab.llm.chat_api
+    without importing heavy dependencies unless actually used.
+
+    Args:
+        name: The name of the attribute to retrieve.
+
+    Returns:
+        The requested class or raises AttributeError if not found.

-        if token is None:
-            token = os.environ["TGI_TOKEN"]
+    Raises:
+        AttributeError: If the requested attribute is not available.
+    """
+    if name == "HuggingFaceURLChatModel":
+        from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel

-        client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
+        return HuggingFaceURLChatModel
+    raise AttributeError(name)


 class VLLMChatModel(ChatModel):
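Note: with the __getattr__ hook above, the public import path is unchanged; only the moment at which the heavy dependency loads moves. A sketch of the expected behavior, assuming the optional HF packages are installed only where an HF backend is actually used:

    import agentlab.llm.chat_api as chat_api  # cheap: no transformers / huggingface_hub import

    # Attribute access triggers chat_api.__getattr__, which performs the real import:
    HFModel = chat_api.HuggingFaceURLChatModel

    # Unknown names still fail loudly:
    # chat_api.NoSuchClass  ->  AttributeError: NoSuchClass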

src/agentlab/llm/huggingface_utils.py

Lines changed: 61 additions & 5 deletions

@@ -1,9 +1,10 @@
 import logging
+import os
 import time
+from functools import partial
 from typing import Any, List, Optional, Union

 from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast

 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
@@ -45,6 +46,14 @@ def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         self.n_retry_server = n_retry_server
         self.log_probs = log_probs

+        # Lazy import to avoid heavy transformers import when unused
+        try:
+            from transformers import AutoTokenizer, GPT2TokenizerFast  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when transformers missing
+            raise ImportError(
+                "The 'transformers' package is required for HuggingFace models. Install it to use HF backends."
+            ) from e
+
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         else:
@@ -60,7 +69,7 @@ def __call__(
         self,
         messages: list[dict],
         n_samples: int = 1,
-        temperature: float = None,
+        temperature: Optional[float] = None,
     ) -> Union[AIMessage, List[AIMessage]]:
         """
         Generate one or more responses for the given messages.
@@ -85,7 +94,7 @@
         except Exception as e:
             if "Conversation roles must alternate" in str(e):
                 logging.warning(
-                    f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
+                    "Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                     "Retrying with the 'system' role appended to the 'user' role."
                 )
                 messages = _prepend_system_to_first_user(messages)
@@ -100,7 +109,11 @@
         itr = 0
         while True:
             try:
-                temperature = temperature if temperature is not None else self.temperature
+                temperature = (
+                    temperature
+                    if temperature is not None
+                    else getattr(self, "temperature", 0.1)
+                )
                 answer = self.llm(prompt, temperature=temperature)
                 response = AIMessage(answer)
                 if self.log_probs:
@@ -144,9 +157,52 @@ def _prepend_system_to_first_user(messages, column_remap={}):
     for msg in messages:
         if msg[role_key] == human_key:
             # Prepend system content to the first user content
-            msg[text_key] = system_content + "\n" + msg[text_key]
+            msg[text_key] = str(system_content) + "\n" + str(msg[text_key])
             # Remove the original system message
             del messages[system_index]
             break  # Ensures that only the first user message is modified

     return messages
+
+
+class HuggingFaceURLChatModel(HFBaseChatModel):
+    """HF backend using a Text Generation Inference (TGI) HTTP endpoint.
+
+    This class is placed here to keep all heavy HF imports optional and only
+    loaded when a HF backend is explicitly requested.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_url: str,
+        base_model_name: Optional[str] = None,
+        token: Optional[str] = None,
+        temperature: Optional[float] = 1e-1,
+        max_new_tokens: Optional[int] = 512,
+        n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
+    ):
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
+        if temperature is not None and temperature < 1e-3:
+            logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
+
+        if token is None:
+            # support both env var names used elsewhere
+            token = os.environ.get("TGI_TOKEN") or os.environ.get("AGENTLAB_MODEL_TOKEN")
+
+        # Lazy import huggingface_hub here to avoid import on non-HF paths
+        try:
+            from huggingface_hub import InferenceClient  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when package missing
+            raise ImportError(
+                "The 'huggingface_hub' package is required for HuggingFace URL backends."
+            ) from e
+
+        client = InferenceClient(model=model_url, token=token)
+        self.llm = partial(
+            client.text_generation,
+            max_new_tokens=max_new_tokens,
+            details=log_probs,
+        )
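For reference, a hedged usage sketch of the relocated class. The endpoint URL and model name are placeholders, a running TGI server is assumed, and messages are assumed to use the default role/content keys:

    from agentlab.llm.chat_api import HuggingFaceURLChatModel  # resolved lazily via __getattr__

    model = HuggingFaceURLChatModel(
        model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder; also used to load the tokenizer
        model_url="http://localhost:8080",  # placeholder TGI endpoint
        temperature=0.1,
        max_new_tokens=256,
    )
    answer = model([{"role": "user", "content": "Summarize this page in one sentence."}])

Instantiation is where the optional dependencies are actually needed: the base class pulls in transformers for the tokenizer, and this class pulls in huggingface_hub for the InferenceClient.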

src/agentlab/llm/llm_utils.py

Lines changed: 18 additions & 3 deletions

@@ -18,7 +18,6 @@
 import tiktoken
 import yaml
 from PIL import Image
-from transformers import AutoModel, AutoTokenizer

 langchain_community = importlib.util.find_spec("langchain_community")
 if langchain_community is not None:
@@ -512,6 +511,13 @@ def get_tokenizer_old(model_name="openai/gpt-4"):
        )
        return tiktoken.encoding_for_model("gpt-4")
    else:
+        # Lazy import of transformers only when needed
+        try:
+            from transformers import AutoTokenizer  # type: ignore
+        except Exception as e:
+            raise ImportError(
+                "The 'transformers' package is required to use non-OpenAI/Azure tokenizers."
+            ) from e
        return AutoTokenizer.from_pretrained(model_name)


@@ -522,6 +528,8 @@ def get_tokenizer(model_name="gpt-4"):
    except KeyError:
        logging.info(f"Could not find a tokenizer for model {model_name}. Trying HuggingFace.")
        try:
+            from transformers import AutoTokenizer  # type: ignore
+
            return AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            logging.info(f"Could not find a tokenizer for model {model_name}: {e} Defaulting to gpt-4.")
@@ -676,8 +684,8 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
    retry_messages = []

    for key in all_keys:
-        if not key in content_dict:
-            if not key in optional_keys:
+        if key not in content_dict:
+            if key not in optional_keys:
                retry_messages.append(f"Missing the key <{key}> in the answer.")
        else:
            val = content_dict[key]
@@ -697,6 +705,13 @@


 def download_and_save_model(model_name: str, save_dir: str = "."):
+    # Lazy import of transformers only when explicitly downloading a model
+    try:
+        from transformers import AutoModel  # type: ignore
+    except Exception as e:
+        raise ImportError(
+            "The 'transformers' package is required to download and save models."
+        ) from e
    model = AutoModel.from_pretrained(model_name)
    model.save_pretrained(save_dir)
    print(f"Model downloaded and saved to {save_dir}")
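The same guarded-import idiom now appears in all three files; if it spreads further, it could be factored into a small helper. A sketch of such a helper (_require is hypothetical and not part of this commit):

    import importlib

    def _require(package: str, feature: str):
        # Import an optional dependency, or raise a uniform, actionable error.
        try:
            return importlib.import_module(package)
        except ImportError as e:
            raise ImportError(
                f"The {package!r} package is required to {feature}. "
                f"Install it with: pip install {package}"
            ) from e

    # e.g. AutoModel = _require("transformers", "download and save models").AutoModel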
