
Commit 2449e54

mattfiamemilio authored and committed
chore(recorder): add support for NOT_GIVEN (llamastack#3430)
# What does this PR do?

The recorder mocks the openai-python interface, and that interface accepts NOT_GIVEN as an input option. This change handles NOT_GIVEN properly by filtering it out of the recorded request parameters.

## Test Plan

ci (coverage for chat, completions, and embeddings)
1 parent 1abe8a3 commit 2449e54
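
For context, a minimal sketch of the sentinel behavior this change handles (the kwargs values below are illustrative): openai-python exports NOT_GIVEN to mark parameters the caller never supplied, and it has to be filtered by identity rather than truthiness, because legitimate values such as None or 0 are also falsy.

```python
# Minimal sketch: dropping openai-python's NOT_GIVEN sentinel from a kwargs
# dict, using the same identity-based filter the recorder change applies.
from openai import NOT_GIVEN

kwargs = {"model": "llama3.2:3b", "temperature": 0.0, "user": NOT_GIVEN}
cleaned = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}

# The identity check matters: temperature=0.0 is falsy but must survive.
assert cleaned == {"model": "llama3.2:3b", "temperature": 0.0}
```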

File tree

2 files changed (+67, -3 lines)


llama_stack/testing/inference_recorder.py
Lines changed: 5 additions & 0 deletions

```diff
@@ -16,6 +16,8 @@
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from openai import NOT_GIVEN
+
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
@@ -250,6 +252,9 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     # Get base URL based on client type
     if client_type == "openai":
         base_url = str(self._client.base_url)
+
+        # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
+        kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
     elif client_type == "ollama":
         # Get base URL from the client (Ollama client uses host attribute)
         base_url = getattr(self, "host", "http://localhost:11434")
```
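
Why the filtering matters: the recorder looks up replayed responses by the request it saw at record time, so a stray NOT_GIVEN in kwargs would make an otherwise identical request look different (and the sentinel is not JSON-serializable in the first place). A rough sketch of that idea, assuming the recorder keys recordings off the serialized request; request_key below is illustrative, not the recorder's actual code:

```python
import hashlib
import json

from openai import NOT_GIVEN


def request_key(endpoint: str, kwargs: dict) -> str:
    # Drop the NOT_GIVEN sentinel first, mirroring the recorder change above;
    # without this, json.dumps would raise TypeError on the sentinel.
    cleaned = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
    payload = json.dumps({"endpoint": endpoint, "kwargs": cleaned}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


# A request that passes user=NOT_GIVEN and one that omits user entirely
# now resolve to the same recording.
with_sentinel = request_key("/v1/completions", {"model": "llama3.2:3b", "prompt": "hi", "user": NOT_GIVEN})
without_user = request_key("/v1/completions", {"model": "llama3.2:3b", "prompt": "hi"})
assert with_sentinel == without_user
```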

tests/unit/distribution/test_inference_recordings.py
Lines changed: 62 additions & 3 deletions

```diff
@@ -9,14 +9,15 @@
 from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
-from openai import AsyncOpenAI
+from openai import NOT_GIVEN, AsyncOpenAI
 from openai.types.model import Model as OpenAIModel
 
 # Import the real Pydantic response types instead of using Mocks
 from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChoice,
+    OpenAICompletion,
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
@@ -170,6 +171,7 @@ async def mock_create(*args, **kwargs):
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
         # Verify the response was returned correctly
@@ -198,6 +200,7 @@ async def mock_create(*args, **kwargs):
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
         # Now test replay mode - should not call the original method
@@ -281,7 +284,11 @@ async def mock_create(*args, **kwargs):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
+                    encoding_format=NOT_GIVEN,
+                    dimensions=NOT_GIVEN,
+                    user=NOT_GIVEN,
                 )
 
                 assert len(response.data) == 2
@@ -292,7 +299,8 @@ async def mock_create(*args, **kwargs):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
                 )
 
                 # Verify we got the recorded response
@@ -302,6 +310,57 @@ async def mock_create(*args, **kwargs):
                 # Verify original method was not called
                 mock_create_patch.assert_not_called()
 
+    async def test_completions_recording(self, temp_storage_dir):
+        real_completions_response = OpenAICompletion(
+            id="test_completion",
+            object="text_completion",
+            created=1234567890,
+            model="llama3.2:3b",
+            choices=[
+                {
+                    "text": "Hello! I'm doing well, thank you for asking.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+        )
+
+        async def mock_create(*args, **kwargs):
+            return real_completions_response
+
+        temp_storage_dir = temp_storage_dir / "test_completions_recording"
+
+        # Record
+        with patch(
+            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                    user=NOT_GIVEN,
+                )
+
+                assert response.choices[0].text == real_completions_response.choices[0].text
+
+        # Replay
+        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                )
+                assert response.choices[0].text == real_completions_response.choices[0].text
+                mock_create_patch.assert_not_called()
+
     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""
 
```
