Skip to content

Commit 6e1b330

Browse files
authored
Add MockOpenAIResponses for use in tests (#2702)
1 parent 9346422 commit 6e1b330

File tree

4 files changed

+208
-178
lines changed

4 files changed

+208
-178
lines changed

tests/models/cassettes/test_openai_responses/test_openai_responses_usage_without_tokens_details.yaml

Lines changed: 0 additions & 96 deletions
This file was deleted.

tests/models/mock_openai.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from __future__ import annotations as _annotations
2+
3+
from collections.abc import Sequence
4+
from dataclasses import dataclass, field
5+
from functools import cached_property
6+
from typing import Any, Union, cast
7+
8+
from ..conftest import raise_if_exception, try_import
9+
from .mock_async_stream import MockAsyncStream
10+
11+
with try_import() as imports_successful:
12+
from openai import NOT_GIVEN, AsyncOpenAI
13+
from openai.types import chat, responses
14+
from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
15+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
16+
from openai.types.completion_usage import CompletionUsage
17+
from openai.types.responses.response import ResponseUsage
18+
from openai.types.responses.response_output_item import ResponseOutputItem
19+
20+
# note: we use Union here so that casting works with Python 3.9
# An `Exception` member in each union means the mock should *raise* that
# exception instead of returning a response (see `raise_if_exception` usage below).
MockChatCompletion = Union[chat.ChatCompletion, Exception]
MockChatCompletionChunk = Union[chat.ChatCompletionChunk, Exception]
MockResponse = Union[responses.Response, Exception]
MockResponseStreamEvent = Union[responses.ResponseStreamEvent, Exception]
25+
26+
27+
@dataclass
class MockOpenAI:
    """Mock stand-in for `AsyncOpenAI` covering the Chat Completions API.

    Provide `completions` for non-streaming calls (`stream=False`) or `stream`
    for streaming calls (`stream=True`). Every call's kwargs are recorded in
    `chat_completion_kwargs` for later inspection via
    `get_mock_chat_completion_kwargs`.
    """

    completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
    stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
    # Index of the next item to serve when a sequence of responses/streams was provided.
    index: int = 0
    # Kwargs of every `create` call, with NOT_GIVEN values filtered out.
    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)

    @cached_property
    def chat(self) -> Any:
        # Mimic the `client.chat.completions.create` attribute chain using ad-hoc types.
        chat_completions = type('Completions', (), {'create': self.chat_completions_create})
        return type('Chat', (), {'completions': chat_completions})

    @classmethod
    def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncOpenAI:
        """Build a non-streaming mock client, cast so it type-checks as `AsyncOpenAI`."""
        return cast(AsyncOpenAI, cls(completions=completions))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]],
    ) -> AsyncOpenAI:
        """Build a streaming mock client, cast so it type-checks as `AsyncOpenAI`."""
        return cast(AsyncOpenAI, cls(stream=stream))

    async def chat_completions_create(  # pragma: lax no cover
        self, *_args: Any, stream: bool = False, **kwargs: Any
    ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]:
        """Replacement for `AsyncOpenAI.chat.completions.create`.

        Records kwargs, then returns the next configured completion or stream;
        an `Exception` entry in `completions` is raised instead of returned.
        """
        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

        if stream:
            # fixed typo in the assertion message: 'used' -> 'use'
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            # A sequence-of-sequences means one stream per call, selected by `index`.
            if isinstance(self.stream[0], Sequence):
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream)))
        else:
            assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided'
            if isinstance(self.completions, Sequence):
                raise_if_exception(self.completions[self.index])
                response = cast(chat.ChatCompletion, self.completions[self.index])
            else:
                raise_if_exception(self.completions)
                response = cast(chat.ChatCompletion, self.completions)
        self.index += 1
        return response
71+
72+
73+
def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    """Return the kwargs recorded by a `MockOpenAI` posing as an `AsyncOpenAI` client."""
    if not isinstance(async_open_ai, MockOpenAI):  # pragma: no cover
        raise RuntimeError('Not a MockOpenAI instance')
    return async_open_ai.chat_completion_kwargs
78+
79+
80+
def completion_message(
    message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
) -> chat.ChatCompletion:
    """Wrap *message* in a minimal `chat.ChatCompletion` suitable for tests.

    Args:
        message: The assistant message to return as the sole choice.
        usage: Optional token usage to attach to the completion.
        logprobs: Optional logprobs to attach to the choice.
    """
    # Build the single choice once: `logprobs` defaults to None on `Choice`,
    # so passing it unconditionally avoids constructing a discarded instance
    # (the original built the list twice when logprobs was given).
    choices = [Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)]
    return chat.ChatCompletion(
        id='123',
        choices=choices,
        created=1704067200,  # 2024-01-01
        model='gpt-4o-123',
        object='chat.completion',
        usage=usage,
    )
94+
95+
96+
@dataclass
class MockOpenAIResponses:
    """Mock stand-in for `AsyncOpenAI` covering the Responses API.

    Provide `response` for non-streaming calls (`stream=False`) or `stream`
    for streaming calls (`stream=True`). Every call's kwargs are recorded in
    `response_kwargs` for later inspection via `get_mock_responses_kwargs`.
    """

    response: MockResponse | Sequence[MockResponse] | None = None
    stream: Sequence[MockResponseStreamEvent] | Sequence[Sequence[MockResponseStreamEvent]] | None = None
    # Index of the next item to serve when a sequence of responses/streams was provided.
    index: int = 0
    # Kwargs of every `create` call, with NOT_GIVEN values filtered out.
    response_kwargs: list[dict[str, Any]] = field(default_factory=list)

    @cached_property
    def responses(self) -> Any:
        # Mimic the `client.responses.create` attribute chain using an ad-hoc type.
        return type('Responses', (), {'create': self.responses_create})

    @classmethod
    def create_mock(cls, responses: MockResponse | Sequence[MockResponse]) -> AsyncOpenAI:
        """Build a non-streaming mock client, cast so it type-checks as `AsyncOpenAI`."""
        return cast(AsyncOpenAI, cls(response=responses))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockResponseStreamEvent] | Sequence[Sequence[MockResponseStreamEvent]],
    ) -> AsyncOpenAI:
        """Build a streaming mock client, cast so it type-checks as `AsyncOpenAI`."""
        return cast(AsyncOpenAI, cls(stream=stream))  # pragma: lax no cover

    async def responses_create(  # pragma: lax no cover
        self, *_args: Any, stream: bool = False, **kwargs: Any
    ) -> responses.Response | MockAsyncStream[MockResponseStreamEvent]:
        """Replacement for `AsyncOpenAI.responses.create`.

        Records kwargs, then returns the next configured response or stream;
        an `Exception` entry in `response` is raised instead of returned.
        """
        self.response_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

        if stream:
            # fixed typo in the assertion message: 'used' -> 'use'
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            # A sequence-of-sequences means one stream per call, selected by `index`.
            if isinstance(self.stream[0], Sequence):
                response = MockAsyncStream(iter(cast(list[MockResponseStreamEvent], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockResponseStreamEvent], self.stream)))
        else:
            # fixed grammar in the assertion message: '`response` are provided' -> 'is provided'
            assert self.response is not None, 'you can only use `stream=False` if `response` is provided'
            if isinstance(self.response, Sequence):
                raise_if_exception(self.response[self.index])
                response = cast(responses.Response, self.response[self.index])
            else:
                raise_if_exception(self.response)
                response = cast(responses.Response, self.response)
        self.index += 1
        return response
139+
140+
141+
def get_mock_responses_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    """Return the kwargs recorded by a `MockOpenAIResponses` posing as an `AsyncOpenAI` client."""
    if not isinstance(async_open_ai, MockOpenAIResponses):  # pragma: no cover
        raise RuntimeError('Not a MockOpenAIResponses instance')
    return async_open_ai.response_kwargs  # pragma: lax no cover
146+
147+
148+
def response_message(
    output_items: Sequence[ResponseOutputItem], *, usage: ResponseUsage | None = None
) -> responses.Response:
    """Wrap *output_items* in a minimal `responses.Response` suitable for tests."""
    # Fixed boilerplate fields shared by every test response.
    fixed_fields: dict[str, Any] = dict(
        id='123',
        model='gpt-4o-123',
        object='response',
        created_at=1704067200,  # 2024-01-01
        parallel_tool_calls=True,
        tool_choice='auto',
        tools=[],
    )
    return responses.Response(output=list(output_items), usage=usage, **fixed_fields)

tests/models/test_openai.py

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from __future__ import annotations as _annotations
22

33
import json
4-
from collections.abc import Sequence
5-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass
65
from datetime import datetime, timezone
76
from enum import Enum
8-
from functools import cached_property
97
from typing import Annotated, Any, Callable, Literal, Union, cast
108

119
import httpx
@@ -47,13 +45,13 @@
4745
from pydantic_ai.tools import ToolDefinition
4846
from pydantic_ai.usage import RequestUsage
4947

50-
from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, TestEnv, raise_if_exception, try_import
51-
from .mock_async_stream import MockAsyncStream
48+
from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, TestEnv, try_import
49+
from .mock_openai import MockOpenAI, completion_message, get_mock_chat_completion_kwargs
5250

5351
with try_import() as imports_successful:
54-
from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI
52+
from openai import APIStatusError, AsyncOpenAI
5553
from openai.types import chat
56-
from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
54+
from openai.types.chat.chat_completion import ChoiceLogprobs
5755
from openai.types.chat.chat_completion_chunk import (
5856
Choice as ChunkChoice,
5957
ChoiceDelta,
@@ -98,75 +96,6 @@ def test_init():
9896
assert m.model_name == 'gpt-4o'
9997

10098

101-
@dataclass
102-
class MockOpenAI:
103-
completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
104-
stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
105-
index: int = 0
106-
chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
107-
108-
@cached_property
109-
def chat(self) -> Any:
110-
chat_completions = type('Completions', (), {'create': self.chat_completions_create})
111-
return type('Chat', (), {'completions': chat_completions})
112-
113-
@classmethod
114-
def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncOpenAI:
115-
return cast(AsyncOpenAI, cls(completions=completions))
116-
117-
@classmethod
118-
def create_mock_stream(
119-
cls,
120-
stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]],
121-
) -> AsyncOpenAI:
122-
return cast(AsyncOpenAI, cls(stream=stream))
123-
124-
async def chat_completions_create( # pragma: lax no cover
125-
self, *_args: Any, stream: bool = False, **kwargs: Any
126-
) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]:
127-
self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
128-
129-
if stream:
130-
assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided'
131-
if isinstance(self.stream[0], Sequence):
132-
response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream[self.index])))
133-
else:
134-
response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream)))
135-
else:
136-
assert self.completions is not None, 'you can only used `stream=False` if `completions` are provided'
137-
if isinstance(self.completions, Sequence):
138-
raise_if_exception(self.completions[self.index])
139-
response = cast(chat.ChatCompletion, self.completions[self.index])
140-
else:
141-
raise_if_exception(self.completions)
142-
response = cast(chat.ChatCompletion, self.completions)
143-
self.index += 1
144-
return response
145-
146-
147-
def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
148-
if isinstance(async_open_ai, MockOpenAI):
149-
return async_open_ai.chat_completion_kwargs
150-
else: # pragma: no cover
151-
raise RuntimeError('Not a MockOpenAI instance')
152-
153-
154-
def completion_message(
155-
message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
156-
) -> chat.ChatCompletion:
157-
choices = [Choice(finish_reason='stop', index=0, message=message)]
158-
if logprobs:
159-
choices = [Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)]
160-
return chat.ChatCompletion(
161-
id='123',
162-
choices=choices,
163-
created=1704067200, # 2024-01-01
164-
model='gpt-4o-123',
165-
object='chat.completion',
166-
usage=usage,
167-
)
168-
169-
17099
async def test_request_simple_success(allow_model_requests: None):
171100
c = completion_message(
172101
ChatCompletionMessage(content='world', role='assistant'),

0 commit comments

Comments
 (0)