|
| 1 | +from __future__ import annotations as _annotations |
| 2 | + |
| 3 | +from collections.abc import Sequence |
| 4 | +from dataclasses import dataclass, field |
| 5 | +from functools import cached_property |
| 6 | +from typing import Any, Union, cast |
| 7 | + |
| 8 | +from ..conftest import raise_if_exception, try_import |
| 9 | +from .mock_async_stream import MockAsyncStream |
| 10 | + |
| 11 | +with try_import() as imports_successful: |
| 12 | + from openai import NOT_GIVEN, AsyncOpenAI |
| 13 | + from openai.types import chat, responses |
| 14 | + from openai.types.chat.chat_completion import Choice, ChoiceLogprobs |
| 15 | + from openai.types.chat.chat_completion_message import ChatCompletionMessage |
| 16 | + from openai.types.completion_usage import CompletionUsage |
| 17 | + from openai.types.responses.response import ResponseUsage |
| 18 | + from openai.types.responses.response_output_item import ResponseOutputItem |
| 19 | + |
| 20 | + # note: we use Union here so that casting works with Python 3.9 |
| 21 | + MockChatCompletion = Union[chat.ChatCompletion, Exception] |
| 22 | + MockChatCompletionChunk = Union[chat.ChatCompletionChunk, Exception] |
| 23 | + MockResponse = Union[responses.Response, Exception] |
| 24 | + MockResponseStreamEvent = Union[responses.ResponseStreamEvent, Exception] |
| 25 | + |
| 26 | + |
@dataclass
class MockOpenAI:
    """Minimal stand-in for `AsyncOpenAI` that serves canned chat-completion results.

    Provide `completions` for non-streaming calls or `stream` for streaming calls;
    each `create` call records its kwargs and advances `index` to the next canned item.
    """

    # A single completion (or exception), a sequence served in call order, or None.
    completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
    # A flat chunk sequence (reused every call) or a sequence of chunk sequences (one per call).
    stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
    # Index of the next canned completion/stream to serve.
    index: int = 0
    # kwargs captured from each `chat.completions.create` call (NOT_GIVEN values dropped).
    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)

    @cached_property
    def chat(self) -> Any:
        # Mimic the `client.chat.completions.create` attribute chain of AsyncOpenAI.
        chat_completions = type('Completions', (), {'create': self.chat_completions_create})
        return type('Chat', (), {'completions': chat_completions})

    @classmethod
    def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncOpenAI:
        """Build a mock client that returns `completions` from non-streaming calls."""
        return cast(AsyncOpenAI, cls(completions=completions))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]],
    ) -> AsyncOpenAI:
        """Build a mock client that yields `stream` chunks from streaming calls."""
        return cast(AsyncOpenAI, cls(stream=stream))

    async def chat_completions_create(  # pragma: lax no cover
        self, *_args: Any, stream: bool = False, **kwargs: Any
    ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]:
        """Replacement for `AsyncOpenAI.chat.completions.create`.

        Records `kwargs`, then serves the next canned completion or stream;
        canned Exceptions are raised instead of returned.
        """
        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

        if stream:
            # fix: original message read "you can only used"
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            # A nested sequence means one chunk-list per call; a flat one is shared across calls.
            if isinstance(self.stream[0], Sequence):
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream)))
        else:
            # fix: original message read "you can only used"
            assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided'
            if isinstance(self.completions, Sequence):
                raise_if_exception(self.completions[self.index])
                response = cast(chat.ChatCompletion, self.completions[self.index])
            else:
                raise_if_exception(self.completions)
                response = cast(chat.ChatCompletion, self.completions)
        self.index += 1
        return response
| 71 | + |
| 72 | + |
def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    """Return the kwargs recorded by a `MockOpenAI` posing as an `AsyncOpenAI`."""
    if not isinstance(async_open_ai, MockOpenAI):  # pragma: no cover
        raise RuntimeError('Not a MockOpenAI instance')
    return async_open_ai.chat_completion_kwargs
| 78 | + |
| 79 | + |
def completion_message(
    message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
) -> chat.ChatCompletion:
    """Wrap `message` in a minimal single-choice `chat.ChatCompletion` with fixed id/model/timestamp."""
    if logprobs:
        choice = Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)
    else:
        choice = Choice(finish_reason='stop', index=0, message=message)
    return chat.ChatCompletion(
        id='123',
        choices=[choice],
        created=1704067200,  # 2024-01-01
        model='gpt-4o-123',
        object='chat.completion',
        usage=usage,
    )
| 94 | + |
| 95 | + |
@dataclass
class MockOpenAIResponses:
    """Minimal stand-in for `AsyncOpenAI` that serves canned Responses-API results.

    Provide `response` for non-streaming calls or `stream` for streaming calls;
    each `create` call records its kwargs and advances `index` to the next canned item.
    """

    # A single response (or exception), a sequence served in call order, or None.
    response: MockResponse | Sequence[MockResponse] | None = None
    # A flat event sequence (reused every call) or a sequence of event sequences (one per call).
    stream: Sequence[MockResponseStreamEvent] | Sequence[Sequence[MockResponseStreamEvent]] | None = None
    # Index of the next canned response/stream to serve.
    index: int = 0
    # kwargs captured from each `responses.create` call (NOT_GIVEN values dropped).
    response_kwargs: list[dict[str, Any]] = field(default_factory=list)

    @cached_property
    def responses(self) -> Any:
        # Mimic the `client.responses.create` attribute chain of AsyncOpenAI.
        return type('Responses', (), {'create': self.responses_create})

    @classmethod
    def create_mock(cls, responses: MockResponse | Sequence[MockResponse]) -> AsyncOpenAI:
        """Build a mock client that returns `responses` from non-streaming calls."""
        return cast(AsyncOpenAI, cls(response=responses))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockResponseStreamEvent] | Sequence[Sequence[MockResponseStreamEvent]],
    ) -> AsyncOpenAI:
        """Build a mock client that yields `stream` events from streaming calls."""
        return cast(AsyncOpenAI, cls(stream=stream))  # pragma: lax no cover

    async def responses_create(  # pragma: lax no cover
        self, *_args: Any, stream: bool = False, **kwargs: Any
    ) -> responses.Response | MockAsyncStream[MockResponseStreamEvent]:
        """Replacement for `AsyncOpenAI.responses.create`.

        Records `kwargs`, then serves the next canned response or stream;
        canned Exceptions are raised instead of returned.
        """
        self.response_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

        if stream:
            # fix: original message read "you can only used"
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            # A nested sequence means one event-list per call; a flat one is shared across calls.
            if isinstance(self.stream[0], Sequence):
                response = MockAsyncStream(iter(cast(list[MockResponseStreamEvent], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockResponseStreamEvent], self.stream)))
        else:
            # fix: original message read "you can only used ... `response` are provided"
            assert self.response is not None, 'you can only use `stream=False` if `response` is provided'
            if isinstance(self.response, Sequence):
                raise_if_exception(self.response[self.index])
                response = cast(responses.Response, self.response[self.index])
            else:
                raise_if_exception(self.response)
                response = cast(responses.Response, self.response)
        self.index += 1
        return response
| 139 | + |
| 140 | + |
def get_mock_responses_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    """Return the kwargs recorded by a `MockOpenAIResponses` posing as an `AsyncOpenAI`."""
    if not isinstance(async_open_ai, MockOpenAIResponses):  # pragma: no cover
        raise RuntimeError('Not a MockOpenAIResponses instance')
    return async_open_ai.response_kwargs  # pragma: lax no cover
| 146 | + |
| 147 | + |
def response_message(
    output_items: Sequence[ResponseOutputItem], *, usage: ResponseUsage | None = None
) -> responses.Response:
    """Build a minimal `responses.Response` containing `output_items`, with fixed id/model/timestamp."""
    return responses.Response(
        object='response',
        id='123',
        created_at=1704067200,  # 2024-01-01
        model='gpt-4o-123',
        tools=[],
        tool_choice='auto',
        parallel_tool_calls=True,
        output=list(output_items),
        usage=usage,
    )
0 commit comments