
Commit 04f71d7

mattfiamemilio authored and committed
chore(recorder): update mocks to be closer to non-mock environment (llamastack#3442)
# What does this PR do?

The @required_args decorator in openai-python masks the async nature of the {AsyncCompletions,chat.AsyncCompletions}.create methods; see openai/openai-python#996. This means two things:

0. we cannot use inspect.iscoroutinefunction in the recorder to detect async vs. non-async methods
1. our AsyncMock-based mocks were replacing create() with a genuine coroutine function, inappropriately introducing identifiable async behavior that the real method does not have

For (0), we replace the iscoroutinefunction check with detection of /v1/models, which is the only non-async endpoint we mock and record. For (1), we could leave everything as is and assume (0) will catch errors; to be defensive, we instead update the unit tests to mock below the create methods, allowing the true openai-python create() methods to be exercised.
1 parent 9f42bbe commit 04f71d7
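
A minimal sketch of the masking behavior described above, assuming a simplified, hypothetical stand-in for openai-python's @required_args (the real decorator also validates required-argument combinations; see openai/openai-python#996). The wrapper it returns is an ordinary function, so inspect.iscoroutinefunction can no longer see the underlying async def:

import asyncio
import inspect


def required_args_like(fn):
    # hypothetical stand-in: a *sync* wrapper that forwards to the async method
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)  # returns a coroutine, but wrapper itself is sync

    return wrapper


class AsyncCompletionsLike:
    @required_args_like
    async def create(self, prompt: str) -> str:
        return f"completion for {prompt!r}"


print(inspect.iscoroutinefunction(AsyncCompletionsLike.create))  # False: async is masked
print(asyncio.run(AsyncCompletionsLike().create("hi")))          # yet the call still needs awaiting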

2 files changed: +117 −113 lines

llama_stack/testing/inference_recorder.py

Lines changed: 6 additions & 8 deletions
@@ -7,7 +7,6 @@
 from __future__ import annotations # for forward references
 
 import hashlib
-import inspect
 import json
 import os
 from collections.abc import Generator
@@ -243,11 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     global _current_mode, _current_storage
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
-        # Normal operation
-        if inspect.iscoroutinefunction(original_method):
-            return await original_method(self, *args, **kwargs)
-        else:
+        if endpoint == "/v1/models":
             return original_method(self, *args, **kwargs)
+        else:
+            return await original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
@@ -298,10 +296,10 @@ async def replay_stream():
         )
 
     elif _current_mode == InferenceMode.RECORD:
-        if inspect.iscoroutinefunction(original_method):
-            response = await original_method(self, *args, **kwargs)
-        else:
+        if endpoint == "/v1/models":
            response = original_method(self, *args, **kwargs)
+        else:
+            response = await original_method(self, *args, **kwargs)
 
         # we want to store the result of the iterator, not the iterator itself
         if endpoint == "/v1/models":
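
To see why the endpoint check replaces the coroutine check, here is a hedged, self-contained sketch (hypothetical names, not the recorder's code). With a masked-async method, the old iscoroutinefunction branch silently hands back an un-awaited coroutine, while dispatching on the endpoint awaits everything except /v1/models:

import asyncio
import inspect


def required_args_like(fn):
    # sync wrapper standing in for openai-python's @required_args
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


@required_args_like
async def create(**kwargs):  # stands in for AsyncCompletions.create
    return {"ok": True}


async def old_dispatch(method):
    if inspect.iscoroutinefunction(method):  # False for the wrapper
        return await method()
    return method()  # bug: returns a coroutine object, never awaited


async def new_dispatch(method, endpoint):
    if endpoint == "/v1/models":  # the one method we call without await
        return method()
    return await method()


async def main():
    old = await old_dispatch(create)
    print(type(old))  # <class 'coroutine'>
    old.close()       # silence the "never awaited" warning
    print(await new_dispatch(create, "/v1/chat/completions"))  # {'ok': True}


asyncio.run(main())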

tests/unit/distribution/test_inference_recordings.py

Lines changed: 111 additions & 105 deletions
@@ -155,71 +155,61 @@ def test_response_storage(self, temp_storage_dir):
 
     async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that recording mode captures and stores responses."""
-
-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
 
-            # Verify the response was returned correctly
-            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            # Verify the response was returned correctly
+            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            client.chat.completions._post.assert_called_once()
 
         # Verify recording was stored
         storage = ResponseStorage(temp_storage_dir)
         assert storage.responses_dir.exists()
 
     async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that replay mode returns stored responses without making real calls."""
-
-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
+            client.chat.completions._post.assert_called_once()
 
         # Now test replay mode - should not call the original method
-        with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
 
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                )
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+            )
 
-            # Verify we got the recorded response
-            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+        # Verify we got the recorded response
+        assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
 
-            # Verify the original method was NOT called
-            mock_create_patch.assert_not_called()
+        # Verify the original method was NOT called
+        client.chat.completions._post.assert_not_called()
 
     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""
@@ -272,43 +262,50 @@ async def test_replay_missing_recording(self, temp_storage_dir):
     async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
         """Test recording and replay of embeddings calls."""
 
-        async def mock_create(*args, **kwargs):
-            return real_embeddings_response
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+        response = await client.embeddings.create(
+            model=real_embeddings_response.model,
+            input=["Hello world", "Test embedding"],
+            encoding_format=NOT_GIVEN,
+        )
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        client.embeddings._post.assert_called_once()
 
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
-        with patch(
-            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                    encoding_format=NOT_GIVEN,
-                    dimensions=NOT_GIVEN,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+                encoding_format=NOT_GIVEN,
+                dimensions=NOT_GIVEN,
+                user=NOT_GIVEN,
+            )
 
-            assert len(response.data) == 2
+            assert len(response.data) == 2
 
         # Replay
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
 
-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                )
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+            )
 
-            # Verify we got the recorded response
-            assert len(response.data) == 2
-            assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        # Verify we got the recorded response
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
 
-            # Verify original method was not called
-            mock_create_patch.assert_not_called()
+        # Verify original method was not called
+        client.embeddings._post.assert_not_called()
 
     async def test_completions_recording(self, temp_storage_dir):
         real_completions_response = OpenAICompletion(
@@ -326,40 +323,49 @@ async def test_completions_recording(self, temp_storage_dir):
             ],
         )
 
-        async def mock_create(*args, **kwargs):
-            return real_completions_response
-
         temp_storage_dir = temp_storage_dir / "test_completions_recording"
 
-        # Record
-        with patch(
-            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.completions._post = AsyncMock(return_value=real_completions_response)
+        response = await client.completions.create(
+            model=real_completions_response.model,
+            prompt="Hello, how are you?",
+            temperature=0.7,
+            max_tokens=50,
+            user=NOT_GIVEN,
+        )
+        assert response.choices[0].text == real_completions_response.choices[0].text
+        client.completions._post.assert_called_once()
 
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+        # Record
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
 
-                assert response.choices[0].text == real_completions_response.choices[0].text
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_called_once()
 
         # Replay
-        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                )
-                assert response.choices[0].text == real_completions_response.choices[0].text
-                mock_create_patch.assert_not_called()
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+            )
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_not_called()
 
     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""
