Skip to content

Commit b6d8d60

Browse files
committed
fix(hf): add workaround to use HF inference endpoints
huggingface/text-generation-inference#3250 langchain-ai/langchain#31434
1 parent a827a60 commit b6d8d60

File tree

1 file changed

+50
-29
lines changed
  • utils_pkg/neuroml_ai_utils

1 file changed

+50
-29
lines changed

utils_pkg/neuroml_ai_utils/llm.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
from langchain_core.messages import AIMessage, HumanMessage
2222
from langchain_core.output_parsers import JsonOutputParser
2323
from langchain_core.prompts import ChatPromptTemplate
24-
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEndpointEmbeddings
24+
from langchain_huggingface import (
25+
ChatHuggingFace,
26+
HuggingFaceEndpoint,
27+
HuggingFaceEndpointEmbeddings,
28+
)
2529

2630

2731
def check_ollama_model(logger, model, exit=False):
@@ -78,34 +82,43 @@ def split_thought_and_output(message: AIMessage):
7882
return thoughts, answer
7983

8084

def check_model_works(model, timeout=30, retries=5):
    """Check that a model actually responds, since loading it does not test it.

    Invokes the model once per attempt with a trivial prompt, retrying with
    exponential backoff between failed attempts.

    :param model: a LangChain runnable/chat model exposing ``invoke``
    :param timeout: per-attempt timeout in seconds; must be non-negative
    :param retries: maximum number of invocation attempts
    :returns: tuple ``(works: bool, message: str)``
    """
    # Validate explicitly instead of `assert`, which is stripped under
    # `python -O`; a graceful (False, msg) return keeps the contract uniform.
    if timeout < 0:
        return False, f"Invalid timeout {timeout}"

    for attempt in range(retries):
        # 1-based attempt numbers to match the success/failure messages below
        print(f"Checking model. Attempt #{attempt + 1}")
        try:
            # Use a very simple prompt with short max_tokens
            result = model.invoke("Hello world", config={"timeout": timeout})
            print(f"Model available (attempt {attempt + 1}): {result}")
            return True, f"Model available (attempt {attempt + 1})"
        except StopIteration as e:
            # Presumably raised when no inference provider serves the model
            # (see huggingface/text-generation-inference#3250); retrying
            # cannot help, so fail immediately.
            return (
                False,
                f"{e.__class__.__name__}: check if any inference providers are available for the selected model",
            )
        except Exception as e:
            error_msg = f"{e.__class__.__name__}: {e}"
            print(f"Attempt #{attempt + 1}: model unavailable: {error_msg}")
            if attempt < retries - 1:
                time.sleep(2**attempt)  # Exponential backoff
            else:
                # Last attempt exhausted — give up.
                # NOTE: the committed version duplicated this whole handler
                # block, causing a double sleep and duplicate log lines on
                # every non-final failed attempt; the duplicate is removed.
                print(f"Model unavailable after {retries} attempts: {error_msg}")
                return (
                    False,
                    f"Model unavailable after {retries} attempts: {error_msg}",
                )
    # Reached only when retries == 0, i.e. the loop body never ran.
    return False, "Unknown error"
110123

111124

@@ -143,16 +156,20 @@ def setup_llm(model_name_full, logger):
143156
hf_token = os.environ.get("HF_TOKEN", None)
144157
assert hf_token
145158

146-
model_var = HuggingFaceEndpoint(
159+
logger.debug("Got HuggingFace Token")
160+
161+
llm = HuggingFaceEndpoint(
147162
repo_id=f"{model_name}",
148163
provider="auto",
149164
max_new_tokens=512,
150165
do_sample=False,
151166
repetition_penalty=1.03,
152-
task="text-generation",
167+
task="conversational", # seems to be ignored, defaults to text-generation
153168
huggingfacehub_api_token=hf_token,
154169
)
155170

171+
model_var = ChatHuggingFace(llm=llm)
172+
156173
"""
157174
158175
model_var = init_chat_model(
@@ -165,7 +182,11 @@ def setup_llm(model_name_full, logger):
165182
"""
166183
assert model_var
167184

168-
state, msg = check_model_works(model_var, timeout=60)
185+
state, msg = check_model_works(model_var, timeout=10, retries=3)
186+
if state:
187+
logger.debug(f"Model works: {state}, {msg}")
188+
else:
189+
logger.debug(f"Model does not work: {state}, {msg}")
169190
assert state
170191
else:
171192
if model_name_full.lower().startswith("ollama:"):

0 commit comments

Comments
 (0)