Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ license = "Apache-2.0"
license-files = ["LICENSE"]
dependencies = [
"jsonpath-ng>=1.7.0",
"litellm==1.71.3",
"litellm>=1.77.2",
"numpy>=2.1.3",
"openai==1.87.0",
"openai>=2.0.0",
"pandas>=2.0.0",
"pandas-stubs>=2.0.0",
"pydantic>=2.9.2",
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import pytest
from dotenv import load_dotenv
import litellm
import asyncio
import contextlib
from collections.abc import Coroutine, Iterator
from importlib.metadata import version
from unittest.mock import patch
import vcr # type: ignore[import-untyped]

from tests.helpers.completion_template_fixtures import (
Expand Down Expand Up @@ -41,6 +46,54 @@ def setup_logging():
yield


@pytest.fixture(autouse=True, scope="session")
def patch_litellm_logging_worker_for_race_condition() -> Iterator[None]:
    """
    Patch litellm's GLOBAL_LOGGING_WORKER for asyncio functionality.

    Replaces the queue-backed worker with a no-op implementation that
    schedules logging coroutines directly on the current event loop,
    avoiding the cross-loop race described in the linked issues.

    SEE: https://github.com/BerriAI/litellm/issues/16518
    SEE: https://github.com/BerriAI/litellm/issues/14521
    """
    try:
        from litellm.litellm_core_utils import logging_worker
    except ImportError:
        import re

        # Extract only the leading numeric components so pre-release or dev
        # version strings (e.g. "1.76.0rc1", "1.72.0.dev1") don't raise a
        # ValueError from a plain int() cast of each dot-separated segment.
        numeric_version = tuple(
            int(part) for part in re.findall(r"\d+", version(litellm.__name__))[:3]
        )
        if numeric_version < (1, 76, 0):
            # Module didn't exist before https://github.com/BerriAI/litellm/pull/13905
            yield
            return
        raise

    class NoOpLoggingWorker:
        """No-op worker that executes callbacks immediately without queuing."""

        def start(self) -> None:
            pass

        def enqueue(self, coroutine: Coroutine) -> None:
            # Execute immediately in the current loop instead of queueing.
            try:
                # This logging task is fire-and-forget
                asyncio.create_task(  # type: ignore[unused-awaitable] # noqa: RUF006
                    coroutine
                )
            except RuntimeError:
                # No running event loop: close the coroutine instead of
                # dropping it, so it doesn't leak and emit a
                # "coroutine was never awaited" RuntimeWarning.
                coroutine.close()

        def ensure_initialized_and_enqueue(self, async_coroutine: Coroutine) -> None:
            self.enqueue(async_coroutine)

        async def stop(self) -> None:
            pass

        async def flush(self) -> None:
            pass

        async def clear_queue(self) -> None:
            pass

    with patch.object(logging_worker, "GLOBAL_LOGGING_WORKER", NoOpLoggingWorker()):
        yield


@pytest.fixture(scope="function")
async def vcr_cassette_async(request: pytest.FixtureRequest):
"""
Expand Down
7 changes: 4 additions & 3 deletions tlm/utils/completion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,11 @@ async def _generate_completion(

if litellm_params.get("logprobs") and hasattr(response.choices[0], "logprobs"):
# Convert ChoiceLogprobs to dict to avoid Pydantic validation issues
choice_logprobs = response.choices[0].logprobs
logprobs = ChoiceLogprobs.model_validate(
response.choices[0].logprobs.model_dump()
if hasattr(response.choices[0].logprobs, "model_dump")
else response.choices[0].logprobs
choice_logprobs.model_dump()
if choice_logprobs and hasattr(choice_logprobs, "model_dump")
else choice_logprobs
)

if raw_message_content := _get_raw_message_content(logprobs):
Expand Down
Loading