2121from langchain_core .messages import AIMessage , HumanMessage
2222from langchain_core .output_parsers import JsonOutputParser
2323from langchain_core .prompts import ChatPromptTemplate
24- from langchain_huggingface import HuggingFaceEndpoint , HuggingFaceEndpointEmbeddings
24+ from langchain_huggingface import (
25+ ChatHuggingFace ,
26+ HuggingFaceEndpoint ,
27+ HuggingFaceEndpointEmbeddings ,
28+ )
2529
2630
2731def check_ollama_model (logger , model , exit = False ):
@@ -78,34 +82,43 @@ def split_thought_and_output(message: AIMessage):
7882 return thoughts , answer
7983
8084
def check_model_works(model, timeout=30, retries=5):
    """Check whether a model actually responds, since availability is not
    verified when it is loaded.

    Invokes the model with a trivial prompt, retrying with exponential
    backoff (2**attempt seconds) between failed attempts.

    Args:
        model: A LangChain-compatible model exposing ``invoke``.
        timeout: Per-request timeout in seconds; must be non-negative.
        retries: Maximum number of invocation attempts.

    Returns:
        tuple[bool, str]: ``(True, message)`` on the first successful
        invocation, ``(False, message)`` otherwise.

    Raises:
        ValueError: If ``timeout`` is negative.
    """
    # Validate rather than assert: asserts are stripped under `python -O`.
    if timeout < 0:
        raise ValueError(f"Invalid timeout {timeout}")

    for attempt in range(retries):
        print(f"Checking model. Attempt #{attempt}")
        try:
            # Use a very simple prompt; the timeout bounds each attempt.
            result = model.invoke("Hello world", config={"timeout": timeout})
            print(f"Model available (attempt {attempt + 1}): {result}")
            return True, f"Model available (attempt {attempt + 1})"
        except StopIteration as e:
            # NOTE(review): observed when no inference provider serves the
            # selected model — retrying will not help, so fail immediately.
            return (
                False,
                f"{e.__class__.__name__}: check if any inference providers are available for the selected model",
            )
        except Exception as e:  # availability probe must not crash the caller
            # Fix: the previous version duplicated this whole handler body,
            # printing twice and sleeping the backoff twice per failure.
            error_msg = f"{e.__class__.__name__}: {e}"
            print(f"Attempt #{attempt}: model unavailable: {error_msg}")
            if attempt < retries - 1:
                time.sleep(2**attempt)  # Exponential backoff
            else:
                print(f"Model unavailable after {retries} attempts: {error_msg}")
                return (
                    False,
                    f"Model unavailable after {retries} attempts: {error_msg}",
                )
    # Reached only when retries == 0 (loop body never executed).
    return False, "Unknown error"
110123
111124
@@ -143,16 +156,20 @@ def setup_llm(model_name_full, logger):
143156 hf_token = os .environ .get ("HF_TOKEN" , None )
144157 assert hf_token
145158
146- model_var = HuggingFaceEndpoint (
159+ logger .debug ("Got HuggingFace Token" )
160+
161+ llm = HuggingFaceEndpoint (
147162 repo_id = f"{ model_name } " ,
148163 provider = "auto" ,
149164 max_new_tokens = 512 ,
150165 do_sample = False ,
151166 repetition_penalty = 1.03 ,
152- task = "text-generation" ,
167+ task = "conversational" , # seems to be ignored, defaults to text-generation
153168 huggingfacehub_api_token = hf_token ,
154169 )
155170
171+ model_var = ChatHuggingFace (llm = llm )
172+
156173 """
157174
158175 model_var = init_chat_model(
@@ -165,7 +182,11 @@ def setup_llm(model_name_full, logger):
165182 """
166183 assert model_var
167184
168- state , msg = check_model_works (model_var , timeout = 60 )
185+ state , msg = check_model_works (model_var , timeout = 10 , retries = 3 )
186+ if state :
187+ logger .debug (f"Model works: { state } , { msg } " )
188+ else :
189+ logger .debug (f"Model does not work: { state } , { msg } " )
169190 assert state
170191 else :
171192 if model_name_full .lower ().startswith ("ollama:" ):
0 commit comments