99from unittest .mock import AsyncMock , Mock , patch
1010
1111import pytest
12- from openai import AsyncOpenAI
12+ from openai import NOT_GIVEN , AsyncOpenAI
1313from openai .types .model import Model as OpenAIModel
1414
1515# Import the real Pydantic response types instead of using Mocks
1616from llama_stack .apis .inference import (
1717 OpenAIAssistantMessageParam ,
1818 OpenAIChatCompletion ,
1919 OpenAIChoice ,
20+ OpenAICompletion ,
2021 OpenAIEmbeddingData ,
2122 OpenAIEmbeddingsResponse ,
2223 OpenAIEmbeddingUsage ,
@@ -170,6 +171,7 @@ async def mock_create(*args, **kwargs):
170171 messages = [{"role" : "user" , "content" : "Hello, how are you?" }],
171172 temperature = 0.7 ,
172173 max_tokens = 50 ,
174+ user = NOT_GIVEN ,
173175 )
174176
175177 # Verify the response was returned correctly
@@ -198,6 +200,7 @@ async def mock_create(*args, **kwargs):
198200 messages = [{"role" : "user" , "content" : "Hello, how are you?" }],
199201 temperature = 0.7 ,
200202 max_tokens = 50 ,
203+ user = NOT_GIVEN ,
201204 )
202205
203206 # Now test replay mode - should not call the original method
@@ -281,7 +284,11 @@ async def mock_create(*args, **kwargs):
281284 client = AsyncOpenAI (base_url = "http://localhost:11434/v1" , api_key = "test" )
282285
283286 response = await client .embeddings .create (
284- model = "nomic-embed-text" , input = ["Hello world" , "Test embedding" ]
287+ model = real_embeddings_response .model ,
288+ input = ["Hello world" , "Test embedding" ],
289+ encoding_format = NOT_GIVEN ,
290+ dimensions = NOT_GIVEN ,
291+ user = NOT_GIVEN ,
285292 )
286293
287294 assert len (response .data ) == 2
@@ -292,7 +299,8 @@ async def mock_create(*args, **kwargs):
292299 client = AsyncOpenAI (base_url = "http://localhost:11434/v1" , api_key = "test" )
293300
294301 response = await client .embeddings .create (
295- model = "nomic-embed-text" , input = ["Hello world" , "Test embedding" ]
302+ model = real_embeddings_response .model ,
303+ input = ["Hello world" , "Test embedding" ],
296304 )
297305
298306 # Verify we got the recorded response
@@ -302,6 +310,57 @@ async def mock_create(*args, **kwargs):
302310 # Verify original method was not called
303311 mock_create_patch .assert_not_called ()
304312
async def test_completions_recording(self, temp_storage_dir):
    """Record a text-completions call, then replay it from storage.

    The first phase patches ``AsyncCompletions.create`` with a stub that
    returns a canned ``OpenAICompletion`` and issues a request in RECORD
    mode. The second phase re-issues the request in REPLAY mode and checks
    that the same text comes back while the patched ``create`` is never
    invoked.

    NOTE(review): the record call passes ``user=NOT_GIVEN`` while the replay
    call omits ``user`` entirely -- presumably to exercise that NOT_GIVEN
    values do not affect request matching; confirm against the recorder.
    """
    create_target = "openai.resources.completions.AsyncCompletions.create"

    expected = OpenAICompletion(
        id="test_completion",
        object="text_completion",
        created=1234567890,
        model="llama3.2:3b",
        choices=[
            {
                "text": "Hello! I'm doing well, thank you for asking.",
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    )
    expected_text = expected.choices[0].text

    async def fake_create(*args, **kwargs):
        # Stand-in for the network call: always hand back the canned response.
        return expected

    # Each test keeps its recordings in its own sub-directory.
    storage_dir = str(temp_storage_dir / "test_completions_recording")

    # Keyword arguments common to the record and the replay request.
    request = {
        "model": expected.model,
        "prompt": "Hello, how are you?",
        "temperature": 0.7,
        "max_tokens": 50,
    }

    # Phase 1: record
    with patch(create_target, new_callable=AsyncMock, side_effect=fake_create):
        with inference_recording(mode=InferenceMode.RECORD, storage_dir=storage_dir):
            recording_client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            recorded = await recording_client.completions.create(**request, user=NOT_GIVEN)

            assert recorded.choices[0].text == expected_text

    # Phase 2: replay
    with patch(create_target) as patched_create:
        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=storage_dir):
            replay_client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            replayed = await replay_client.completions.create(**request)

            assert replayed.choices[0].text == expected_text
            # Replay must be served from storage, never from the patched method.
            patched_create.assert_not_called()
363+
305364 async def test_live_mode (self , real_openai_chat_response ):
306365 """Test that live mode passes through to original methods."""
307366
0 commit comments