import asyncio
import contextlib
import hashlib
import logging
import os
import threading
import time
from collections import deque
from typing import Any

from dotenv import load_dotenv
from litellm import acompletion, completion, litellm
from litellm.exceptions import (
    APIConnectionError,
    RateLimitError,
    Timeout,
)
from tenacity import AsyncRetrying, retry_if_exception_type, wait_exponential
1719
1820RETRYABLE_EXCEPTIONS = (
1921 APIConnectionError ,
@@ -248,8 +250,17 @@ def _safe_release_sync(self) -> None:
248250
249251 async def acquire_async (self ) -> None :
250252 """Async leaky-bucket acquisition (release is scheduled)."""
251- if self ._async_sem is None :
253+ # Create new semaphore for each event loop to avoid binding issues
254+ try :
255+ loop = asyncio .get_running_loop ()
256+ if (
257+ self ._async_sem is None
258+ or getattr (self ._async_sem , "_loop" , None ) != loop
259+ ):
260+ self ._async_sem = asyncio .Semaphore (self .requests_per_second )
261+ except RuntimeError :
252262 self ._async_sem = asyncio .Semaphore (self .requests_per_second )
263+
253264 await self ._async_sem .acquire ()
254265 delay = 1.0 / float (self .requests_per_second )
255266 asyncio .get_running_loop ().call_later (delay , self ._safe_release_async )
@@ -260,7 +271,250 @@ def _safe_release_async(self) -> None:
260271 with contextlib .suppress (ValueError ):
261272 self ._async_sem .release ()
262273
263- # ... (rest of the code remains the same)
274+
# Module-wide limiter shared by every ModuleLLM instance; both generate()
# and agenerate() acquire a slot from it before each provider call.
_global_rate_limiter = GlobalRateLimiter(requests_per_second=20)
277+
278+
279+ class ModuleLLM :
280+ """
281+ A module that provides a simple interface for using LLMs with performance optimizations.
282+
283+ Note : Currently supports OpenAI, Anthropic, xAI, Huggingface, Ollama, OpenRouter, NovitaAI, Gemini
284+ """
285+
286+ def __init__ (
287+ self ,
288+ llm_model : str ,
289+ api_base : str | None = None ,
290+ system_prompt : str | None = None ,
291+ enable_caching : bool = False ,
292+ enable_batching : bool = False ,
293+ cache_size : int = 1000 ,
294+ cache_ttl : float = 300.0 ,
295+ batch_size : int = 10 ,
296+ ):
297+ """
298+ Initialize LLM module with optional performance optimizations
299+
300+ Args:
301+ llm_model: The model to use for LLM in format
302+ "{provider}/{model}" (for example, "openai/gpt-4o").
303+ api_base: The API base to use if LLM provider is Ollama
304+ system_prompt: The system prompt to use for LLM
305+ enable_caching: Enable response caching for performance
306+ enable_batching: Enable request batching for performance
307+ cache_size: Maximum number of cached responses
308+ cache_ttl: Cache time-to-live in seconds
309+ batch_size: Number of requests to batch together
310+
311+ Raises:
312+ ValueError: If llm_model is not in the expected "{provider}/{model}"
313+ format, or if the provider API key is missing.
314+ """
315+ self .api_base = api_base
316+ self .llm_model = llm_model
317+ self .system_prompt = system_prompt
318+
319+ # Performance optimizations
320+ self .enable_caching = enable_caching
321+ self .enable_batching = enable_batching
322+
323+ # Initialize optimization components
324+ if enable_caching :
325+ self .cache = ResponseCache (max_size = cache_size , default_ttl = cache_ttl )
326+
327+ if enable_batching :
328+ self .batcher = RequestBatcher (batch_size = batch_size )
329+ # Start batch processing task only if event loop is running
330+ try :
331+ asyncio .get_running_loop ()
332+ self ._batch_task = asyncio .create_task (self .batcher ._process_batch ())
333+ except RuntimeError :
334+ # No event loop running, will create task when needed
335+ self ._batch_task = None
336+
337+ self .connection_pool = ConnectionPool ()
338+
339+ # Performance tracking
340+ self .request_count = 0
341+ self .cache_hits = 0
342+ self .batch_count = 0
343+
344+ if "/" not in llm_model :
345+ raise ValueError (
346+ f"Invalid model format '{ llm_model } '. "
347+ "Expected '{provider}/{model}', e.g. 'openai/gpt-4o'."
348+ )
349+
350+ provider = self .llm_model .split ("/" )[0 ].upper ()
351+
352+ if provider in ["OLLAMA" , "OLLAMA_CHAT" ]:
353+ if self .api_base is None :
354+ self .api_base = "http://localhost:11434"
355+ logger .warning (
356+ "Using default Ollama API base: %s. If inference is not working, you may need to set the API base to the correct URL." ,
357+ self .api_base ,
358+ )
359+ self .api_key = "your_default_api_key" # Add this line
360+ else :
361+ try :
362+ self .api_key = os .environ [f"{ provider } _API_KEY" ]
363+ except KeyError as err :
364+ raise ValueError (
365+ f"No API key found for { provider } . Please set the { provider } _API_KEY environment variable (e.g., in your .env file)."
366+ ) from err
367+
368+ if not litellm .supports_function_calling (model = self .llm_model ):
369+ logger .warning (
370+ "%s does not support function calling. This model may not be able to use tools. Please check the model documentation at https://docs.litellm.ai/docs/providers for more information." ,
371+ self .llm_model ,
372+ )
373+
374+ def _build_messages (self , prompt : str | list [str ] | None = None ) -> list [dict ]:
375+ """
376+ Format the prompt messages for the LLM of the form : {"role": ..., "content": ...}
377+
378+ Args:
379+ prompt: The prompt to generate a response for (str, list of strings, or None)
380+
381+ Returns:
382+ The messages for the LLM
383+ """
384+ messages = []
385+
386+ # Always include a system message. Default to empty string if no system prompt to support Ollama
387+ system_content = self .system_prompt if self .system_prompt else ""
388+ messages .append ({"role" : "system" , "content" : system_content })
389+
390+ if prompt :
391+ if isinstance (prompt , str ):
392+ messages .append ({"role" : "user" , "content" : prompt })
393+ elif isinstance (prompt , list ):
394+ # Use extend to add all prompts from the list
395+ messages .extend ([{"role" : "user" , "content" : p } for p in prompt ])
396+
397+ return messages
398+
399+ def generate (
400+ self ,
401+ prompt : str | list [str ] | None = None ,
402+ tool_schema : list [dict ] | None = None ,
403+ tool_choice : str = "auto" ,
404+ response_format : dict | object | None = None ,
405+ ) -> str :
406+ """
407+ Generate a response from LLM using litellm based on prompt
408+
409+ Args:
410+ prompt: The prompt to generate a response for (str, list of strings, or None)
411+ tool_schema: The schema of tools to use
412+ tool_choice: The choice of tool to use
413+ response_format: The format of response
414+
415+ Returns:
416+ The response from the LLM
417+ """
418+ # Apply global rate limiting
419+ _global_rate_limiter .acquire_sync ()
420+ try :
421+ self .request_count += 1
422+ messages = self ._build_messages (prompt )
423+
424+ # Check cache first if enabled
425+ cached_response = None
426+ if self .enable_caching :
427+ cached_response = self .cache .get (self .llm_model , messages )
428+ if cached_response is not None :
429+ self .cache_hits += 1
430+ return cached_response
431+
432+ completion_kwargs = {
433+ "model" : self .llm_model ,
434+ "messages" : messages ,
435+ "tools" : tool_schema ,
436+ "tool_choice" : tool_choice if tool_schema else None ,
437+ "response_format" : response_format ,
438+ }
439+ if self .api_base :
440+ completion_kwargs ["api_base" ] = self .api_base
441+
442+ response = completion (** completion_kwargs )
443+
444+ # Cache response if enabled
445+ if self .enable_caching :
446+ self .cache .set (self .llm_model , messages , response )
447+
448+ return response
449+ finally :
450+ # Sync limiter releases via timer
451+ pass
452+
453+ async def agenerate (
454+ self ,
455+ prompt : str | list [str ] | None = None ,
456+ tool_schema : list [dict ] | None = None ,
457+ tool_choice : str = "auto" ,
458+ response_format : dict | object | None = None ,
459+ ) -> str :
460+ """
461+ Asynchronous version of generate() method for parallel LLM calls.
462+ """
463+ # Apply global rate limiting
464+ await _global_rate_limiter .acquire_async ()
465+ try :
466+ self .request_count += 1
467+ messages = self ._build_messages (prompt )
468+
469+ # Check cache first if enabled
470+ cached_response = None
471+ if self .enable_caching :
472+ cached_response = self .cache .get (self .llm_model , messages )
473+ if cached_response is not None :
474+ self .cache_hits += 1
475+ return cached_response
476+
477+ # Use batching if enabled
478+ if self .enable_batching :
479+ request_data = {
480+ "model" : self .llm_model ,
481+ "messages" : messages ,
482+ "tools" : tool_schema ,
483+ "tool_choice" : tool_choice if tool_schema else None ,
484+ "response_format" : response_format ,
485+ "api_base" : self .api_base ,
486+ }
487+ response = await self .batcher .add_request (request_data )
488+ self .batch_count += 1
489+ else :
490+ async for attempt in AsyncRetrying (
491+ wait = wait_exponential (
492+ multiplier = 1.1 , min = 1 , max = 5
493+ ), # Gentler backoff
494+ retry = retry_if_exception_type (RETRYABLE_EXCEPTIONS ),
495+ reraise = True ,
496+ ):
497+ with attempt :
498+ completion_kwargs = {
499+ "model" : self .llm_model ,
500+ "messages" : messages ,
501+ "tools" : tool_schema ,
502+ "tool_choice" : tool_choice if tool_schema else None ,
503+ "response_format" : response_format ,
504+ }
505+ if self .api_base :
506+ completion_kwargs ["api_base" ] = self .api_base
507+
508+ response = await acompletion (** completion_kwargs )
509+
510+ # Cache response if enabled
511+ if self .enable_caching :
512+ self .cache .set (self .llm_model , messages , response )
513+
514+ return response
515+ finally :
516+ # Async limiter releases via scheduled callback
517+ pass
264518
265519 def get_performance_stats (self ) -> dict :
266520 """Get performance statistics."""
@@ -284,3 +538,34 @@ async def cleanup(self):
284538 await self ._batch_task
285539
286540 self .connection_pool .cleanup ()
541+
542+
# Helper methods for performance stats and cleanup. They read ModuleLLM
# state (request_count, cache_hits, batch_count, cache, connection_pool,
# _batch_task), so they must be attached to ModuleLLM: attaching them to
# GlobalRateLimiter (as a previous revision did) produced methods that
# always raised AttributeError when called.
def _get_performance_stats(self) -> dict:
    """Return request/cache/batch counters, plus cache stats when enabled."""
    stats = {
        "request_count": self.request_count,
        "cache_hits": self.cache_hits,
        # max(1, ...) guards against division by zero before any request.
        "cache_hit_rate": self.cache_hits / max(1, self.request_count),
        "batch_count": self.batch_count,
    }

    if self.enable_caching:
        stats.update(self.cache.get_stats())

    return stats


async def _async_cleanup(self):
    """Cancel the batch task (if any) and release pooled connections."""
    # _batch_task may be absent (batching disabled) or None (no event loop
    # was running at init time); only cancel a real task.
    task = getattr(self, "_batch_task", None)
    if task is not None:
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task

    self.connection_pool.cleanup()


# Attach only when ModuleLLM does not already define them, so any in-class
# implementation always wins over these module-level fallbacks.
if not hasattr(ModuleLLM, "get_performance_stats"):
    ModuleLLM.get_performance_stats = _get_performance_stats
if not hasattr(ModuleLLM, "async_cleanup"):
    ModuleLLM.async_cleanup = _async_cleanup
0 commit comments