@@ -1,9 +1,10 @@
 import logging
+import os
 import time
+from functools import partial
 from typing import Any, List, Optional, Union
 
 from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast
 
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
@@ -45,6 +46,14 @@ def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         self.n_retry_server = n_retry_server
         self.log_probs = log_probs
 
+        # Lazy import to avoid heavy transformers import when unused
+        try:
+            from transformers import AutoTokenizer, GPT2TokenizerFast  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when transformers missing
+            raise ImportError(
+                "The 'transformers' package is required for HuggingFace models. Install it to use HF backends."
+            ) from e
+
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         else:
@@ -60,7 +69,7 @@ def __call__(
         self,
         messages: list[dict],
         n_samples: int = 1,
-        temperature: float = None,
+        temperature: Optional[float] = None,
     ) -> Union[AIMessage, List[AIMessage]]:
         """
         Generate one or more responses for the given messages.
@@ -85,7 +94,7 @@ def __call__(
         except Exception as e:
             if "Conversation roles must alternate" in str(e):
                 logging.warning(
-                    f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
+                    "Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                     "Retrying with the 'system' role appended to the 'user' role."
                 )
                 messages = _prepend_system_to_first_user(messages)
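
# A minimal sketch of the fallback applied above, assuming the default
# column_remap (i.e. "role"/"content" keys with "system" and "user" roles);
# illustrative only, not part of this diff:
msgs = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi"},
]
msgs = _prepend_system_to_first_user(msgs)
# msgs is now [{"role": "user", "content": "You are concise.\nHi"}]
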
@@ -100,7 +109,11 @@ def __call__(
         itr = 0
         while True:
             try:
-                temperature = temperature if temperature is not None else self.temperature
+                temperature = (
+                    temperature
+                    if temperature is not None
+                    else getattr(self, "temperature", 0.1)
+                )
                 answer = self.llm(prompt, temperature=temperature)
                 response = AIMessage(answer)
                 if self.log_probs:
@@ -144,9 +157,52 @@ def _prepend_system_to_first_user(messages, column_remap={}):
     for msg in messages:
         if msg[role_key] == human_key:
             # Prepend system content to the first user content
-            msg[text_key] = system_content + "\n" + msg[text_key]
+            msg[text_key] = str(system_content) + "\n" + str(msg[text_key])
             # Remove the original system message
             del messages[system_index]
             break  # Ensures that only the first user message is modified
 
     return messages
+
+
+class HuggingFaceURLChatModel(HFBaseChatModel):
+    """HF backend using a Text Generation Inference (TGI) HTTP endpoint.
+
+    This class is placed here to keep all heavy HF imports optional and only
+    loaded when a HF backend is explicitly requested.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_url: str,
+        base_model_name: Optional[str] = None,
+        token: Optional[str] = None,
+        temperature: Optional[float] = 1e-1,
+        max_new_tokens: Optional[int] = 512,
+        n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
+    ):
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
+        if temperature is not None and temperature < 1e-3:
+            logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
+
+        if token is None:
+            # support both env var names used elsewhere
+            token = os.environ.get("TGI_TOKEN") or os.environ.get("AGENTLAB_MODEL_TOKEN")
+
+        # Lazy import huggingface_hub here to avoid import on non-HF paths
+        try:
+            from huggingface_hub import InferenceClient  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when package missing
+            raise ImportError(
+                "The 'huggingface_hub' package is required for HuggingFace URL backends."
+            ) from e
+
+        client = InferenceClient(model=model_url, token=token)
+        self.llm = partial(
+            client.text_generation,
+            max_new_tokens=max_new_tokens,
+            details=log_probs,
+        )
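
# A minimal usage sketch of the new HuggingFaceURLChatModel, assuming a TGI
# server is already serving the model; the import path, model id, URL, and
# token below are illustrative assumptions, not part of this diff:
from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel  # path assumed

llm = HuggingFaceURLChatModel(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical model id
    model_url="http://localhost:8080",  # hypothetical TGI endpoint
    token="<your-token>",  # or set TGI_TOKEN / AGENTLAB_MODEL_TOKEN in the environment
    temperature=0.1,
)
answer = llm([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello."},
])  # an AIMessage; a list of AIMessage when n_samples > 1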