Skip to content

Commit cf1a2cb

Browse files
authored
[AI] test: Add support for recording requests from the openai package (#34424)
* test: Add support for recording requests sent through openai * chore: remove unused import * chore: Initialize assets.json * fix: Spelling error in docstring * chore: Add openai as dev_requirement for azure-ai-resources * test: Ensure enum is serialized as value
1 parent fa4b95f commit cf1a2cb

File tree

7 files changed

+227
-2
lines changed

7 files changed

+227
-2
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"AssetsRepo": "Azure/azure-sdk-assets",
3+
"AssetsRepoPrefixPath": "python",
4+
"TagPrefix": "python/ai/azure-ai-generative",
5+
"Tag": ""
6+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Implementation of an httpx.Client that forwards traffic to the Azure SDK test-proxy.
2+
3+
.. note::
4+
5+
This module has side-effects!
6+
7+
Importing this module will replace the default httpx.Client used
8+
by the openai package with one that can redirect it's traffic
9+
to the Azure SDK test-proxy on demand.
10+
11+
"""
12+
from contextlib import contextmanager
13+
from typing import Iterable, Literal, Optional
14+
15+
import httpx
16+
import openai._base_client
17+
from typing_extensions import override
18+
from dataclasses import dataclass
19+
20+
21+
@dataclass
22+
class TestProxyConfig:
23+
recording_id: str
24+
"""The ID for the ongoing test recording."""
25+
26+
recording_mode: Literal["playback", "record"]
27+
"""The current recording mode."""
28+
29+
proxy_url: str
30+
"""The url for the Azure SDK test proxy."""
31+
32+
33+
class TestProxyHttpxClient(openai._base_client.SyncHttpxClientWrapper):
34+
recording_config: Optional[TestProxyConfig] = None
35+
36+
@classmethod
37+
def is_recording(cls) -> bool:
38+
"""Whether we are forwarding requests to the test proxy
39+
40+
:return: True if forwarding, False otherwise
41+
:rtype: bool
42+
"""
43+
return cls.recording_config is not None
44+
45+
@classmethod
46+
@contextmanager
47+
def record_with_proxy(cls, config: TestProxyConfig) -> Iterable[None]:
48+
"""Forward all requests made within the scope of context manager to test-proxy.
49+
50+
:param TestProxyConfig config: The test proxy configuration
51+
"""
52+
cls.recording_config = config
53+
54+
yield
55+
56+
cls.recording_config = None
57+
58+
@override
59+
def send(self, request: httpx.Request, **kwargs) -> httpx.Response:
60+
if self.is_recording():
61+
return self._send_to_proxy(request, **kwargs)
62+
else:
63+
return super().send(request, **kwargs)
64+
65+
def _send_to_proxy(self, request: httpx.Request, **kwargs) -> httpx.Response:
66+
"""Forwards a network request to the test proxy
67+
68+
:param httpx.Request request: The request to send
69+
:keyword **kwargs: The kwargs accepted by httpx.Client.send
70+
:return: The request's response
71+
:rtype: httpx.Response
72+
"""
73+
assert self.is_recording(), f"{self._send_to_proxy.__qualname__} should only be called while recording"
74+
config = self.recording_config
75+
original_url = request.url
76+
77+
request_path = original_url.copy_with(scheme="", netloc=b"")
78+
request.url = httpx.URL(config.proxy_url).join(request_path)
79+
80+
headers = request.headers
81+
if headers.get("x-recording-upstream-base-uri", None) is None:
82+
headers["x-recording-upstream-base-uri"] = str(
83+
httpx.URL(scheme=original_url.scheme, netloc=original_url.netloc)
84+
)
85+
headers["x-recording-id"] = config.recording_id
86+
headers["x-recording-mode"] = config.recording_mode
87+
88+
response = super().send(request, **kwargs)
89+
90+
response.request.url = original_url
91+
return response
92+
93+
94+
# openai._base_client.SyncHttpxClientWrapper is default httpx.Client instantiated by openai
95+
openai._base_client.SyncHttpxClientWrapper = TestProxyHttpxClient

sdk/ai/azure-ai-generative/tests/conftest.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __openai_patcher import TestProxyConfig, TestProxyHttpxClient # isort: split
12
import asyncio
23
import base64
34
import os
@@ -6,7 +7,6 @@
67
import pytest
78
from azure.ai.generative.synthetic.qa import QADataGenerator
89

9-
import pytest
1010
from packaging import version
1111
from devtools_testutils import (
1212
FakeTokenCredential,
@@ -17,6 +17,8 @@
1717
is_live,
1818
set_custom_default_matcher,
1919
)
20+
from devtools_testutils.config import PROXY_URL
21+
from devtools_testutils.helpers import get_recording_id
2022
from devtools_testutils.proxy_fixtures import EnvironmentVariableSanitizer
2123

2224
from azure.ai.resources.client import AIClient
@@ -25,6 +27,18 @@
2527
from azure.identity import AzureCliCredential, ClientSecretCredential
2628

2729

30+
@pytest.fixture()
31+
def recorded_test(recorded_test):
32+
"""Route requests from the openai package to the test proxy."""
33+
34+
config = TestProxyConfig(
35+
recording_id=get_recording_id(), recording_mode="record" if is_live() else "playback", proxy_url=PROXY_URL
36+
)
37+
38+
39+
with TestProxyHttpxClient.record_with_proxy(config):
40+
yield recorded_test
41+
2842
@pytest.fixture()
2943
def ai_client(
3044
e2e_subscription_id: str,

sdk/ai/azure-ai-generative/tests/synthetic_qa/unittests/test_qa_data_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_export_format(self, qa_type, structure):
9191
qa_generator = QADataGenerator(model_config)
9292
qas = list(zip(questions, answers))
9393
filepath = os.path.join(pathlib.Path(__file__).parent.parent.resolve(), "data")
94-
output_file = os.path.join(filepath, f"test_{qa_type}_{structure}.jsonl")
94+
output_file = os.path.join(filepath, f"test_{qa_type.value}_{structure.value}.jsonl")
9595
qa_generator.export_to_file(output_file, qa_type, qas, structure)
9696

9797
if qa_type == QAType.CONVERSATION and structure == OutputStructure.CHAT_PROTOCOL:

sdk/ai/azure-ai-resources/dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
-e ../../ml/azure-ai-ml
66
pytest
77
pytest-xdist
8+
openai
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Implementation of an httpx.Client that forwards traffic to the Azure SDK test-proxy.
2+
3+
.. note::
4+
5+
This module has side-effects!
6+
7+
Importing this module will replace the default httpx.Client used
8+
by the openai package with one that can redirect it's traffic
9+
to the Azure SDK test-proxy on demand.
10+
11+
"""
12+
from contextlib import contextmanager
13+
from typing import Iterable, Literal, Optional
14+
15+
import httpx
16+
import openai._base_client
17+
from typing_extensions import override
18+
from dataclasses import dataclass
19+
20+
21+
@dataclass
22+
class TestProxyConfig:
23+
recording_id: str
24+
"""The ID for the ongoing test recording."""
25+
26+
recording_mode: Literal["playback", "record"]
27+
"""The current recording mode."""
28+
29+
proxy_url: str
30+
"""The url for the Azure SDK test proxy."""
31+
32+
33+
class TestProxyHttpxClient(openai._base_client.SyncHttpxClientWrapper):
34+
recording_config: Optional[TestProxyConfig] = None
35+
36+
@classmethod
37+
def is_recording(cls) -> bool:
38+
"""Whether we are forwarding requests to the test proxy
39+
40+
:return: True if forwarding, False otherwise
41+
:rtype: bool
42+
"""
43+
return cls.recording_config is not None
44+
45+
@classmethod
46+
@contextmanager
47+
def record_with_proxy(cls, config: TestProxyConfig) -> Iterable[None]:
48+
"""Forward all requests made within the scope of context manager to test-proxy.
49+
50+
:param TestProxyConfig config: The test proxy configuration
51+
"""
52+
cls.recording_config = config
53+
54+
yield
55+
56+
cls.recording_config = None
57+
58+
@override
59+
def send(self, request: httpx.Request, **kwargs) -> httpx.Response:
60+
if self.is_recording():
61+
return self._send_to_proxy(request, **kwargs)
62+
else:
63+
return super().send(request, **kwargs)
64+
65+
def _send_to_proxy(self, request: httpx.Request, **kwargs) -> httpx.Response:
66+
"""Forwards a network request to the test proxy
67+
68+
:param httpx.Request request: The request to send
69+
:keyword **kwargs: The kwargs accepted by httpx.Client.send
70+
:return: The request's response
71+
:rtype: httpx.Response
72+
"""
73+
assert self.is_recording(), f"{self._send_to_proxy.__qualname__} should only be called while recording"
74+
config = self.recording_config
75+
original_url = request.url
76+
77+
request_path = original_url.copy_with(scheme="", netloc=b"")
78+
request.url = httpx.URL(config.proxy_url).join(request_path)
79+
80+
headers = request.headers
81+
if headers.get("x-recording-upstream-base-uri", None) is None:
82+
headers["x-recording-upstream-base-uri"] = str(
83+
httpx.URL(scheme=original_url.scheme, netloc=original_url.netloc)
84+
)
85+
headers["x-recording-id"] = config.recording_id
86+
headers["x-recording-mode"] = config.recording_mode
87+
88+
response = super().send(request, **kwargs)
89+
90+
response.request.url = original_url
91+
return response
92+
93+
94+
# openai._base_client.SyncHttpxClientWrapper is default httpx.Client instantiated by openai
95+
openai._base_client.SyncHttpxClientWrapper = TestProxyHttpxClient

sdk/ai/azure-ai-resources/tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __openai_patcher import TestProxyConfig, TestProxyHttpxClient # isort: split
12
import asyncio
23
import base64
34
import os
@@ -17,6 +18,8 @@
1718
is_live,
1819
set_custom_default_matcher,
1920
)
21+
from devtools_testutils.config import PROXY_URL
22+
from devtools_testutils.helpers import get_recording_id
2023
from devtools_testutils.proxy_fixtures import (
2124
EnvironmentVariableSanitizer,
2225
VariableRecorder
@@ -38,6 +41,17 @@ def generate_random_string():
3841
return generate_random_string
3942

4043

44+
@pytest.fixture()
45+
def recorded_test(recorded_test):
46+
"""Route requests from the openai package to the test proxy."""
47+
48+
config = TestProxyConfig(
49+
recording_id=get_recording_id(), recording_mode="record" if is_live() else "playback", proxy_url=PROXY_URL
50+
)
51+
52+
53+
with TestProxyHttpxClient.record_with_proxy(config):
54+
yield recorded_test
4155

4256
@pytest.fixture()
4357
def ai_client(

0 commit comments

Comments
 (0)