Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ license = "Apache-2.0"
license-files = ["LICENSE"]
dependencies = [
"jsonpath-ng>=1.7.0",
"litellm==1.71.3",
"litellm>=1.77.2",
"numpy>=2.1.3",
"openai==1.87.0",
"openai>=2.0.0",
"pandas>=2.0.0",
"pandas-stubs>=2.0.0",
"pydantic>=2.9.2",
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import pytest
from dotenv import load_dotenv
import litellm
import asyncio
import contextlib
from collections.abc import Coroutine, Iterator
from importlib.metadata import version
from unittest.mock import patch
import vcr # type: ignore[import-untyped]

from tests.helpers.completion_template_fixtures import (
Expand Down Expand Up @@ -41,6 +46,54 @@ def setup_logging():
yield


@pytest.fixture(autouse=True, scope="session")
def patch_litellm_logging_worker_for_race_condition() -> Iterator[None]:
"""
Patch litellm's GLOBAL_LOGGING_WORKER for asyncio functionality.

SEE: https://github.com/BerriAI/litellm/issues/16518
SEE: https://github.com/BerriAI/litellm/issues/14521
"""
try:
from litellm.litellm_core_utils import logging_worker
except ImportError:
if tuple(int(x) for x in version(litellm.__name__).split(".")) < (1, 76, 0):
# Module didn't exist before https://github.com/BerriAI/litellm/pull/13905
yield
return
raise

class NoOpLoggingWorker:
"""No-op worker that executes callbacks immediately without queuing."""

def start(self) -> None:
pass

def enqueue(self, coroutine: Coroutine) -> None:
# Execute immediately in current loop instead of queueing,
# and do nothing if there's no current loop
with contextlib.suppress(RuntimeError):
# This logging task is fire-and-forget
asyncio.create_task( # type: ignore[unused-awaitable] # noqa: RUF006
coroutine
)

def ensure_initialized_and_enqueue(self, async_coroutine: Coroutine) -> None:
self.enqueue(async_coroutine)

async def stop(self) -> None:
pass

async def flush(self) -> None:
pass

async def clear_queue(self) -> None:
pass

with patch.object(logging_worker, "GLOBAL_LOGGING_WORKER", NoOpLoggingWorker()):
yield


@pytest.fixture(scope="function")
async def vcr_cassette_async(request: pytest.FixtureRequest):
"""
Expand Down
2 changes: 1 addition & 1 deletion tlm/utils/completion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async def _generate_completion(
# Convert ChoiceLogprobs to dict to avoid Pydantic validation issues
logprobs = ChoiceLogprobs.model_validate(
response.choices[0].logprobs.model_dump()
if hasattr(response.choices[0].logprobs, "model_dump")
if response.choices[0].logprobs and hasattr(response.choices[0].logprobs, "model_dump")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this check necessary? it's already within the condition if litellm_params.get("logprobs") and hasattr(response.choices[0], "logprobs"):

else response.choices[0].logprobs
)

Expand Down
Loading