@@ -73,7 +73,7 @@ def _ensure_connection(self):
     def _run_coroutine(self, coro, timeout: Optional[float] = None):
        """Run a coroutine in the event loop thread with timeout and error handling."""
         self._ensure_connection()
-
+
         try:
             future = asyncio.run_coroutine_threadsafe(coro, self._loop)
             return future.result(timeout=timeout)
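For reference, `asyncio.run_coroutine_threadsafe` is the standard-library bridge this helper leans on: it schedules a coroutine onto an event loop running in another thread and hands back a `concurrent.futures.Future` whose `result(timeout=...)` blocks the calling thread. A minimal self-contained sketch of the pattern, independent of this client:

    import asyncio
    import threading

    # Run an event loop in a dedicated background thread.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def work(x: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for real async I/O
        return x * 2

    # From the main (sync) thread, schedule the coroutine on the background loop.
    future = asyncio.run_coroutine_threadsafe(work(21), loop)
    print(future.result(timeout=5.0))  # blocks up to 5s, prints 42

    loop.call_soon_threadsafe(loop.stop)  # stop the loop cleanly when done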
@@ -123,10 +123,10 @@ def close(self):
             # Clean up
             self._loop = None
             self._thread = None
-
+
             # Shutdown executor
             self._executor.shutdown(wait=False)
-
+
         except Exception as e:
             logger.error(f"Error during client cleanup: {str(e)}")
         finally:
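Because `close()` tears down the loop thread and the executor, callers should pair construction with cleanup. A usage sketch, assuming the surrounding class is exposed as `SyncModelClient` (the class and module names are hypothetical; this diff only shows the methods):

    from llm_client import SyncModelClient  # hypothetical import path

    client = SyncModelClient()
    try:
        text = client.generate("Explain event loops in one paragraph.")
        print(text)
    finally:
        client.close()  # stops the loop thread and shuts down the executor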
@@ -139,25 +139,40 @@ def generate(
         stream: bool = False,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Union[str, Generator[str, None, None]]:
         """
-        Generate text using the model.
+        Generate text using the model with improved quality settings.
 
         Args:
             prompt: The prompt to generate text from
             model_id: Optional model ID to use
             stream: Whether to stream the response
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             If stream=False, returns the generated text as a string.
             If stream=True, returns a generator that yields chunks of text.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         if stream:
-            return self.stream_generate(prompt, model_id, max_length, temperature, top_p)
+            return self.stream_generate(
+                prompt=prompt,
+                model_id=model_id,
+                max_length=max_length,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
+            )
 
         return self._run_coroutine(
             self._async_client.generate(
@@ -166,7 +181,10 @@ def generate(
                 stream=False,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k,
+                timeout=180.0  # Increased timeout for more complete responses (3 minutes)
             )
         )
 
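With these defaults, a bare call now samples with repetition_penalty=1.15 and top_k=80, allows up to 4096 tokens, and waits up to three minutes; each knob can still be overridden per call. A sketch, reusing the `client` instance from the earlier example:

    text = client.generate(
        "Write a haiku about concurrency.",
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,  # push repetition down further for this call
        top_k=40,                # narrower candidate pool than the default 80
        max_length=256,          # override the 4096-token default
    )
    print(text)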
@@ -177,22 +195,29 @@ def stream_generate(
         max_length: Optional[int] = None,
         temperature: float = 0.7,
         top_p: float = 0.9,
-        timeout: float = 60.0
+        timeout: float = 300.0,  # Increased timeout for more complete responses (5 minutes)
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Generator[str, None, None]:
         """
-        Stream text generation.
+        Stream text generation with improved quality and reliability.
 
         Args:
             prompt: The prompt to generate text from
             model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
             timeout: Request timeout in seconds
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             A generator that yields chunks of text as they are generated.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         # Create a queue to pass data between the async and sync worlds
         queue = asyncio.Queue()
         stop_event = threading.Event()
@@ -206,7 +231,10 @@ async def producer():
                 max_length=max_length,
                 temperature=temperature,
                 top_p=top_p,
-                timeout=timeout
+                timeout=timeout,
+                retry_count=3,  # Increased retry count for better reliability
+                repetition_penalty=repetition_penalty,  # Pass the repetition penalty parameter
+                top_k=top_k  # Pass the top_k parameter
             ):
                 await queue.put(chunk)
 
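The part of the bridge not shown in this hunk drains the queue from the calling thread. A self-contained sketch of the general shape, using a None sentinel for end-of-stream (this illustrates the pattern, not necessarily the exact code in this file):

    import asyncio
    import threading
    from typing import Generator, Optional

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def fake_stream():  # stand-in for self._async_client.stream_generate(...)
        for word in ["hello", "from", "the", "loop"]:
            await asyncio.sleep(0.05)
            yield word

    def sync_stream() -> Generator[str, None, None]:
        queue: asyncio.Queue = asyncio.Queue()

        async def producer():
            async for chunk in fake_stream():
                await queue.put(chunk)
            await queue.put(None)  # sentinel: no more chunks

        asyncio.run_coroutine_threadsafe(producer(), loop)
        while True:
            # queue.get() is a coroutine, so run it on the loop thread per item.
            chunk: Optional[str] = asyncio.run_coroutine_threadsafe(queue.get(), loop).result()
            if chunk is None:
                break
            yield chunk

    print(list(sync_stream()))  # ['hello', 'from', 'the', 'loop']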
@@ -250,25 +278,41 @@ def chat(
         stream: bool = False,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
         """
-        Chat completion.
+        Chat completion with improved quality settings.
 
         Args:
             messages: List of message dictionaries with 'role' and 'content' keys
             model_id: Optional model ID to use
             stream: Whether to stream the response
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             If stream=False, returns the chat completion response.
             If stream=True, returns a generator that yields chunks of the response.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         if stream:
-            return self.stream_chat(messages, model_id, max_length, temperature, top_p)
+            return self.stream_chat(
+                messages=messages,
+                model_id=model_id,
+                max_length=max_length,
+                temperature=temperature,
+                top_p=top_p,
+                timeout=300.0,  # Increased timeout for more complete responses (5 minutes)
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
+            )
 
         return self._run_coroutine(
             self._async_client.chat(
@@ -277,7 +321,10 @@ def chat(
                 stream=False,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                timeout=180.0,  # Increased timeout for more complete responses (3 minutes)
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
             )
         )
 
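A usage sketch for the non-streaming path, with the message format documented above (the exact shape of the returned dict is defined by the async client, so it is printed raw here):

    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarize what a repetition penalty does."},
    ]
    response = client.chat(messages)  # blocks until done or the 180s timeout elapses
    print(response)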
@@ -287,21 +334,29 @@ def stream_chat(
         model_id: Optional[str] = None,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        timeout: float = 300.0,  # Increased timeout for more complete responses (5 minutes)
+        repetition_penalty: float = 1.15,  # Added repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Generator[Dict[str, Any], None, None]:
         """
-        Stream chat completion.
+        Stream chat completion with improved quality and reliability.
 
         Args:
             messages: List of message dictionaries with 'role' and 'content' keys
             model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            timeout: Request timeout in seconds
 
         Returns:
             A generator that yields chunks of the chat completion response.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         # Create a queue to pass data between the async and sync worlds
         queue = asyncio.Queue()
         stop_event = threading.Event()
@@ -314,7 +369,11 @@ async def producer():
                 model_id=model_id,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                timeout=timeout,
+                retry_count=3,  # Increased retry count for better reliability
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
             ):
                 await queue.put(chunk)
 
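From synchronous code the streaming path then reads as plain iteration; each item is a chunk dict whose structure comes from the async client (not shown in this diff):

    for chunk in client.stream_chat(messages, temperature=0.7, top_k=80):
        print(chunk, flush=True)  # chunk shape depends on the async client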
@@ -357,28 +416,39 @@ def batch_generate(
         model_id: Optional[str] = None,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80,  # Added top_k parameter for better quality
+        timeout: float = 300.0  # Added timeout parameter (5 minutes)
     ) -> Dict[str, List[str]]:
         """
-        Generate text for multiple prompts in parallel.
+        Generate text for multiple prompts in parallel with improved quality settings.
 
         Args:
             prompts: List of prompts to generate text from
             model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             Dictionary with the generated responses.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         return self._run_coroutine(
             self._async_client.batch_generate(
                 prompts=prompts,
                 model_id=model_id,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k,
+                timeout=timeout  # Use the provided timeout parameter
             )
         )
 
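Finally, the batch path fans several prompts out through the async client in a single blocking call:

    prompts = [
        "Define 'event loop' in one sentence.",
        "Define 'thread pool' in one sentence.",
    ]
    results = client.batch_generate(prompts, max_length=128, timeout=120.0)
    print(results)  # dict of generated responses; key layout depends on the async client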