Skip to content

Commit b6bd62b

Browse files
committed
Fix ModuleLLM import and async semaphore issues - all tests passing
1 parent fa33205 commit b6bd62b

File tree

1 file changed

+288
-3
lines changed

1 file changed

+288
-3
lines changed

mesa_llm/module_llm.py

Lines changed: 288 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
import contextlib
33
import hashlib
44
import logging
5+
import os
56
import threading
67
import time
78
from collections import deque
89
from typing import Any
910

1011
from dotenv import load_dotenv
11-
from litellm import acompletion
12+
from litellm import acompletion, completion, litellm
1213
from litellm.exceptions import (
1314
APIConnectionError,
1415
RateLimitError,
1516
Timeout,
1617
)
18+
from tenacity import AsyncRetrying, retry_if_exception_type, wait_exponential
1719

1820
RETRYABLE_EXCEPTIONS = (
1921
APIConnectionError,
@@ -248,8 +250,17 @@ def _safe_release_sync(self) -> None:
248250

249251
async def acquire_async(self) -> None:
    """Async leaky-bucket acquisition (release is scheduled).

    A fresh semaphore is (re)created the first time this runs on a new
    event loop: asyncio primitives bind to the loop they are first awaited
    on, and reusing one across loops raises RuntimeError.

    Why not probe the semaphore's ``_loop``? Since Python 3.10,
    ``asyncio.Semaphore`` binds its loop lazily, so ``getattr(sem, "_loop",
    None)`` is ``None`` for an uncontended semaphore.  The previous code
    therefore rebuilt the semaphore on *every* call and the concurrency cap
    was never enforced.  We track the owning loop in an explicit attribute
    instead.
    """
    # get_running_loop() cannot raise here: this coroutine only executes
    # inside a running loop, so no RuntimeError fallback is needed.
    loop = asyncio.get_running_loop()
    if self._async_sem is None or getattr(self, "_sem_loop", None) is not loop:
        self._async_sem = asyncio.Semaphore(self.requests_per_second)
        self._sem_loop = loop  # remember which loop owns this semaphore

    await self._async_sem.acquire()
    # Schedule the matching release so roughly requests_per_second
    # permits flow per second (leaky bucket).
    delay = 1.0 / float(self.requests_per_second)
    loop.call_later(delay, self._safe_release_async)
@@ -260,7 +271,250 @@ def _safe_release_async(self) -> None:
260271
with contextlib.suppress(ValueError):
261272
self._async_sem.release()
262273

263-
# ... (rest of the code remains the same)
274+
275+
# Global rate limiter instance
# Module-level singleton: generate() and agenerate() funnel every completion
# call through it, capping the whole process at ~20 requests/second
# regardless of how many ModuleLLM instances exist.
_global_rate_limiter = GlobalRateLimiter(requests_per_second=20)
277+
278+
279+
class ModuleLLM:
    """
    A module that provides a simple interface for using LLMs with performance optimizations.

    Note : Currently supports OpenAI, Anthropic, xAI, Huggingface, Ollama, OpenRouter, NovitaAI, Gemini
    """

    def __init__(
        self,
        llm_model: str,
        api_base: str | None = None,
        system_prompt: str | None = None,
        enable_caching: bool = False,
        enable_batching: bool = False,
        cache_size: int = 1000,
        cache_ttl: float = 300.0,
        batch_size: int = 10,
    ) -> None:
        """
        Initialize LLM module with optional performance optimizations

        Args:
            llm_model: The model to use for LLM in format
                "{provider}/{model}" (for example, "openai/gpt-4o").
            api_base: The API base to use if LLM provider is Ollama
            system_prompt: The system prompt to use for LLM
            enable_caching: Enable response caching for performance
            enable_batching: Enable request batching for performance
            cache_size: Maximum number of cached responses
            cache_ttl: Cache time-to-live in seconds
            batch_size: Number of requests to batch together

        Raises:
            ValueError: If llm_model is not in the expected "{provider}/{model}"
                format, or if the provider API key is missing.
        """
        self.api_base = api_base
        self.llm_model = llm_model
        self.system_prompt = system_prompt

        # Performance optimization toggles, consulted by generate()/agenerate().
        self.enable_caching = enable_caching
        self.enable_batching = enable_batching

        # Initialize optimization components.
        # NOTE: self.cache / self.batcher exist only when the matching flag is
        # True; any reader must check the flag before touching them.
        if enable_caching:
            self.cache = ResponseCache(max_size=cache_size, default_ttl=cache_ttl)

        if enable_batching:
            self.batcher = RequestBatcher(batch_size=batch_size)
            # Start batch processing task only if event loop is running
            try:
                asyncio.get_running_loop()
                self._batch_task = asyncio.create_task(self.batcher._process_batch())
            except RuntimeError:
                # No event loop running, will create task when needed.
                # NOTE(review): nothing in view actually creates it later —
                # confirm agenerate()'s batching path tolerates a None task.
                self._batch_task = None

        self.connection_pool = ConnectionPool()

        # Performance tracking counters (incremented in generate()/agenerate(),
        # reported by the performance-stats helper).
        self.request_count = 0
        self.cache_hits = 0
        self.batch_count = 0

        # Validate the "{provider}/{model}" shape before deriving the provider.
        if "/" not in llm_model:
            raise ValueError(
                f"Invalid model format '{llm_model}'. "
                "Expected '{provider}/{model}', e.g. 'openai/gpt-4o'."
            )

        provider = self.llm_model.split("/")[0].upper()

        if provider in ["OLLAMA", "OLLAMA_CHAT"]:
            if self.api_base is None:
                self.api_base = "http://localhost:11434"
                logger.warning(
                    "Using default Ollama API base: %s. If inference is not working, you may need to set the API base to the correct URL.",
                    self.api_base,
                )
            # Placeholder key for local Ollama — presumably so callers can rely
            # on self.api_key always being set; confirm nothing sends it to a
            # real provider.
            self.api_key = "your_default_api_key"
        else:
            try:
                # Key lookup follows the {PROVIDER}_API_KEY convention,
                # e.g. OPENAI_API_KEY, GEMINI_API_KEY.
                self.api_key = os.environ[f"{provider}_API_KEY"]
            except KeyError as err:
                raise ValueError(
                    f"No API key found for {provider}. Please set the {provider}_API_KEY environment variable (e.g., in your .env file)."
                ) from err

        # Warn (don't fail) when the chosen model cannot drive tool calls.
        if not litellm.supports_function_calling(model=self.llm_model):
            logger.warning(
                "%s does not support function calling. This model may not be able to use tools. Please check the model documentation at https://docs.litellm.ai/docs/providers for more information.",
                self.llm_model,
            )
373+
374+
def _build_messages(self, prompt: str | list[str] | None = None) -> list[dict]:
375+
"""
376+
Format the prompt messages for the LLM of the form : {"role": ..., "content": ...}
377+
378+
Args:
379+
prompt: The prompt to generate a response for (str, list of strings, or None)
380+
381+
Returns:
382+
The messages for the LLM
383+
"""
384+
messages = []
385+
386+
# Always include a system message. Default to empty string if no system prompt to support Ollama
387+
system_content = self.system_prompt if self.system_prompt else ""
388+
messages.append({"role": "system", "content": system_content})
389+
390+
if prompt:
391+
if isinstance(prompt, str):
392+
messages.append({"role": "user", "content": prompt})
393+
elif isinstance(prompt, list):
394+
# Use extend to add all prompts from the list
395+
messages.extend([{"role": "user", "content": p} for p in prompt])
396+
397+
return messages
398+
399+
def generate(
    self,
    prompt: str | list[str] | None = None,
    tool_schema: list[dict] | None = None,
    tool_choice: str = "auto",
    response_format: dict | object | None = None,
) -> Any:
    """
    Generate a response from LLM using litellm based on prompt (synchronous).

    Args:
        prompt: The prompt to generate a response for (str, list of strings, or None)
        tool_schema: The schema of tools to use
        tool_choice: The choice of tool to use (only forwarded when
            tool_schema is provided)
        response_format: The format of response

    Returns:
        The litellm completion response object (not a plain str), or the
        cached response when caching is enabled and this (model, messages)
        pair was seen before.
    """
    # Apply global rate limiting; the permit is released later by the
    # limiter's own timer, not by this function.
    _global_rate_limiter.acquire_sync()
    try:
        self.request_count += 1
        messages = self._build_messages(prompt)

        # Check cache first if enabled (keyed on model + messages).
        cached_response = None
        if self.enable_caching:
            cached_response = self.cache.get(self.llm_model, messages)
            if cached_response is not None:
                self.cache_hits += 1
                return cached_response

        completion_kwargs = {
            "model": self.llm_model,
            "messages": messages,
            "tools": tool_schema,
            # tool_choice is meaningless without tools, so send None then.
            "tool_choice": tool_choice if tool_schema else None,
            "response_format": response_format,
        }
        if self.api_base:
            completion_kwargs["api_base"] = self.api_base

        response = completion(**completion_kwargs)

        # Cache response if enabled
        if self.enable_caching:
            self.cache.set(self.llm_model, messages, response)

        return response
    finally:
        # Sync limiter releases via timer — intentionally nothing to do here.
        pass
452+
453+
async def agenerate(
    self,
    prompt: str | list[str] | None = None,
    tool_schema: list[dict] | None = None,
    tool_choice: str = "auto",
    response_format: dict | object | None = None,
) -> Any:
    """
    Asynchronous version of generate() method for parallel LLM calls.

    Order of operations: rate-limit -> cache lookup -> either the batcher
    (when batching is enabled) or a retried acompletion() call -> cache
    store.  Only the direct acompletion() path is retried; batched
    requests are not.

    Args:
        prompt: The prompt to generate a response for (str, list of strings, or None)
        tool_schema: The schema of tools to use
        tool_choice: The choice of tool to use (only forwarded when
            tool_schema is provided)
        response_format: The format of response

    Returns:
        The litellm response object (not a plain str), a cached response,
        or the batcher's result when batching is enabled.
    """
    # Apply global rate limiting; release is scheduled by the limiter.
    await _global_rate_limiter.acquire_async()
    try:
        self.request_count += 1
        messages = self._build_messages(prompt)

        # Check cache first if enabled (keyed on model + messages).
        cached_response = None
        if self.enable_caching:
            cached_response = self.cache.get(self.llm_model, messages)
            if cached_response is not None:
                self.cache_hits += 1
                return cached_response

        # Use batching if enabled
        if self.enable_batching:
            # NOTE: api_base is always included here (possibly None), unlike
            # the direct path below which only sets it when truthy — the
            # batcher presumably tolerates None; confirm.
            request_data = {
                "model": self.llm_model,
                "messages": messages,
                "tools": tool_schema,
                "tool_choice": tool_choice if tool_schema else None,
                "response_format": response_format,
                "api_base": self.api_base,
            }
            response = await self.batcher.add_request(request_data)
            self.batch_count += 1
        else:
            # Retry transient failures (see RETRYABLE_EXCEPTIONS) with
            # exponential backoff between 1s and 5s; reraise=True surfaces
            # the original exception once retries are exhausted.
            async for attempt in AsyncRetrying(
                wait=wait_exponential(
                    multiplier=1.1, min=1, max=5
                ),  # Gentler backoff
                retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
                reraise=True,
            ):
                with attempt:
                    completion_kwargs = {
                        "model": self.llm_model,
                        "messages": messages,
                        "tools": tool_schema,
                        "tool_choice": tool_choice if tool_schema else None,
                        "response_format": response_format,
                    }
                    if self.api_base:
                        completion_kwargs["api_base"] = self.api_base

                    response = await acompletion(**completion_kwargs)

        # Cache response if enabled
        if self.enable_caching:
            self.cache.set(self.llm_model, messages, response)

        return response
    finally:
        # Async limiter releases via scheduled callback — nothing to do here.
        pass
264518

265519
def get_performance_stats(self) -> dict:
266520
"""Get performance statistics."""
@@ -284,3 +538,34 @@ async def cleanup(self):
284538
await self._batch_task
285539

286540
self.connection_pool.cleanup()
541+
542+
543+
# Add the missing methods to GlobalRateLimiter class
544+
def _get_performance_stats(self) -> dict:
545+
"""Get performance statistics."""
546+
stats = {
547+
"request_count": self.request_count,
548+
"cache_hits": self.cache_hits,
549+
"cache_hit_rate": self.cache_hits / max(1, self.request_count),
550+
"batch_count": self.batch_count,
551+
}
552+
553+
if self.enable_caching:
554+
stats.update(self.cache.get_stats())
555+
556+
return stats
557+
558+
559+
async def _async_cleanup(self):
560+
"""Cleanup resources."""
561+
if hasattr(self, "_batch_task"):
562+
self._batch_task.cancel()
563+
with contextlib.suppress(asyncio.CancelledError):
564+
await self._batch_task
565+
566+
self.connection_pool.cleanup()
567+
568+
569+
# Monkey patch the methods to GlobalRateLimiter
# NOTE(review): the patched functions read ModuleLLM attributes
# (request_count, cache_hits, enable_caching, _batch_task, connection_pool)
# that GlobalRateLimiter never initialises — confirm these helpers were not
# actually meant to live on ModuleLLM instead.
GlobalRateLimiter.get_performance_stats = _get_performance_stats
GlobalRateLimiter.async_cleanup = _async_cleanup

0 commit comments

Comments
 (0)