Skip to content

Commit f31b0ee

Browse files
test stopgap support for dall-e in v1 (Azure#33031)
1 parent daf44bb commit f31b0ee

File tree

3 files changed

+144
-10
lines changed

3 files changed

+144
-10
lines changed

sdk/openai/azure-openai/tests/conftest.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
# Licensed under the MIT License. See License.txt in the project root for
55
# license information.
66
# --------------------------------------------------------------------------
7+
import time
8+
import asyncio
9+
import json
10+
import httpx
711
import os
812
import pytest
913
import importlib
@@ -31,6 +35,7 @@
3135
WHISPER_AZURE = "whisper_azure"
3236
WHISPER_AZURE_AD = "whisper_azuread"
3337
WHISPER_ALL = ["whisper_azure", "whisper_azuread", "openai"]
38+
DALLE_AZURE = "dalle_azure"
3439

3540
# Environment variable keys
3641
ENV_AZURE_OPENAI_ENDPOINT = "AZ_OPENAI_ENDPOINT"
@@ -101,6 +106,121 @@ def azure_openai_creds():
101106

102107
# openai>=1.0.0 ---------------------------------------------------------------------------
103108

109+
class CustomHTTPTransport(httpx.HTTPTransport):
110+
"""Temp stop-gap support for DALL-E"""
111+
def handle_request(
112+
self,
113+
request: httpx.Request,
114+
) -> httpx.Response:
115+
if "images/generations" in request.url.path and request.url.params[
116+
"api-version"
117+
] in [
118+
"2023-06-01-preview",
119+
"2023-07-01-preview",
120+
"2023-08-01-preview",
121+
"2023-09-01-preview",
122+
"2023-10-01-preview",
123+
]:
124+
request.url = request.url.copy_with(path="/openai/images/generations:submit")
125+
response = super().handle_request(request)
126+
operation_location_url = response.headers["operation-location"]
127+
request.url = httpx.URL(operation_location_url)
128+
request.method = "GET"
129+
response = super().handle_request(request)
130+
response.read()
131+
132+
timeout_secs: int = 120
133+
start_time = time.time()
134+
while response.json()["status"] not in ["succeeded", "failed"]:
135+
if time.time() - start_time > timeout_secs:
136+
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
137+
return httpx.Response(
138+
status_code=400,
139+
headers=response.headers,
140+
content=json.dumps(timeout).encode("utf-8"),
141+
request=request,
142+
)
143+
144+
time.sleep(int(response.headers.get("retry-after")) or 10)
145+
response = super().handle_request(request)
146+
response.read()
147+
148+
if response.json()["status"] == "failed":
149+
error_data = response.json()
150+
return httpx.Response(
151+
status_code=400,
152+
headers=response.headers,
153+
content=json.dumps(error_data).encode("utf-8"),
154+
request=request,
155+
)
156+
157+
result = response.json()["result"]
158+
return httpx.Response(
159+
status_code=200,
160+
headers=response.headers,
161+
content=json.dumps(result).encode("utf-8"),
162+
request=request,
163+
)
164+
return super().handle_request(request)
165+
166+
167+
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
168+
"""Temp stop-gap support for DALL-E"""
169+
async def handle_async_request(
170+
self,
171+
request: httpx.Request,
172+
) -> httpx.Response:
173+
if "images/generations" in request.url.path and request.url.params[
174+
"api-version"
175+
] in [
176+
"2023-06-01-preview",
177+
"2023-07-01-preview",
178+
"2023-08-01-preview",
179+
"2023-09-01-preview",
180+
"2023-10-01-preview",
181+
]:
182+
request.url = request.url.copy_with(path="/openai/images/generations:submit")
183+
response = await super().handle_async_request(request)
184+
operation_location_url = response.headers["operation-location"]
185+
request.url = httpx.URL(operation_location_url)
186+
request.method = "GET"
187+
response = await super().handle_async_request(request)
188+
await response.aread()
189+
190+
timeout_secs: int = 120
191+
start_time = time.time()
192+
while response.json()["status"] not in ["succeeded", "failed"]:
193+
if time.time() - start_time > timeout_secs:
194+
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
195+
return httpx.Response(
196+
status_code=400,
197+
headers=response.headers,
198+
content=json.dumps(timeout).encode("utf-8"),
199+
request=request,
200+
)
201+
202+
await asyncio.sleep(int(response.headers.get("retry-after")) or 10)
203+
response = await super().handle_async_request(request)
204+
await response.aread()
205+
206+
if response.json()["status"] == "failed":
207+
error_data = response.json()
208+
return httpx.Response(
209+
status_code=400,
210+
headers=response.headers,
211+
content=json.dumps(error_data).encode("utf-8"),
212+
request=request,
213+
)
214+
215+
result = response.json()["result"]
216+
return httpx.Response(
217+
status_code=200,
218+
headers=response.headers,
219+
content=json.dumps(result).encode("utf-8"),
220+
request=request,
221+
)
222+
return await super().handle_async_request(request)
223+
104224
@pytest.fixture
105225
def client(api_type):
106226
if os.getenv(ENV_OPENAI_TEST_MODE, "v1") != "v1":
@@ -133,6 +253,13 @@ def client(api_type):
133253
azure_ad_token_provider=get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"),
134254
api_version=ENV_AZURE_OPENAI_API_VERSION,
135255
)
256+
elif api_type == "dalle_azure":
257+
client = openai.AzureOpenAI(
258+
azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT),
259+
api_key=os.getenv(ENV_AZURE_OPENAI_KEY),
260+
api_version=ENV_AZURE_OPENAI_API_VERSION,
261+
http_client=httpx.Client(transport=CustomHTTPTransport())
262+
)
136263

137264
return client
138265

@@ -169,6 +296,13 @@ def client_async(api_type):
169296
azure_ad_token_provider=get_bearer_token_provider_async(AsyncDefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"),
170297
api_version=ENV_AZURE_OPENAI_API_VERSION,
171298
)
299+
elif api_type == "dalle_azure":
300+
client = openai.AsyncAzureOpenAI(
301+
azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT),
302+
api_key=os.getenv(ENV_AZURE_OPENAI_KEY),
303+
api_version=ENV_AZURE_OPENAI_API_VERSION,
304+
http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport())
305+
)
172306

173307
return client
174308

sdk/openai/azure-openai/tests/v1_tests/test_dall_e.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
import pytest
77
import openai
88
from devtools_testutils import AzureRecordedTestCase
9-
from conftest import configure, AZURE, OPENAI, ALL
9+
from conftest import configure, AZURE, OPENAI, ALL, DALLE_AZURE
1010

1111

1212
class TestDallE(AzureRecordedTestCase):
1313

1414
@configure
1515
# @pytest.mark.parametrize("api_type", ALL) # re-enable when supported
16-
@pytest.mark.parametrize("api_type", [OPENAI])
16+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
1717
def test_image_create(self, client, azure_openai_creds, api_type, **kwargs):
1818
image = client.images.generate(
1919
prompt="a cute baby seal"
@@ -24,7 +24,7 @@ def test_image_create(self, client, azure_openai_creds, api_type, **kwargs):
2424

2525
@configure
2626
# @pytest.mark.parametrize("api_type", [AZURE, OPENAI]) # re-enable when supported
27-
@pytest.mark.parametrize("api_type", [OPENAI])
27+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
2828
def test_image_create_n(self, client, azure_openai_creds, api_type, **kwargs):
2929
image = client.images.generate(
3030
prompt="a cute baby seal",
@@ -37,7 +37,7 @@ def test_image_create_n(self, client, azure_openai_creds, api_type, **kwargs):
3737

3838
@configure
3939
# @pytest.mark.parametrize("api_type", [AZURE, OPENAI]) # re-enable when supported
40-
@pytest.mark.parametrize("api_type", [OPENAI])
40+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
4141
def test_image_create_size(self, client, azure_openai_creds, api_type, **kwargs):
4242
image = client.images.generate(
4343
prompt="a cute baby seal",
@@ -60,7 +60,7 @@ def test_image_create_response_format(self, client, azure_openai_creds, api_type
6060

6161
@configure
6262
# @pytest.mark.parametrize("api_type", [AZURE, OPENAI]) # re-enable when supported
63-
@pytest.mark.parametrize("api_type", [OPENAI])
63+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
6464
def test_image_create_user(self, client, azure_openai_creds, api_type, **kwargs):
6565
image = client.images.generate(
6666
prompt="a cute baby seal",

sdk/openai/azure-openai/tests/v1_tests/test_dall_e_async.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
import pytest
77
import openai
88
from devtools_testutils import AzureRecordedTestCase
9-
from conftest import AZURE, OPENAI, ALL, configure_async
9+
from conftest import AZURE, OPENAI, ALL, configure_async, DALLE_AZURE
1010

1111

1212
class TestDallEAsync(AzureRecordedTestCase):
1313

1414
@configure_async
1515
@pytest.mark.asyncio
1616
# @pytest.mark.parametrize("api_type", ALL) # re-enable when supported
17-
@pytest.mark.parametrize("api_type", [OPENAI])
17+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
1818
async def test_image_create(self, client_async, azure_openai_creds, api_type, **kwargs):
1919
image = await client_async.images.generate(
2020
prompt="a cute baby seal"
@@ -26,7 +26,7 @@ async def test_image_create(self, client_async, azure_openai_creds, api_type, **
2626
@configure_async
2727
@pytest.mark.asyncio
2828
# @pytest.mark.parametrize("api_type", [AZURE, OPENAI]) # re-enable when supported
29-
@pytest.mark.parametrize("api_type", [OPENAI])
29+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
3030
async def test_image_create_n(self, client_async, azure_openai_creds, api_type, **kwargs):
3131
image = await client_async.images.generate(
3232
prompt="a cute baby seal",
@@ -40,7 +40,7 @@ async def test_image_create_n(self, client_async, azure_openai_creds, api_type,
4040
@configure_async
4141
@pytest.mark.asyncio
4242
# @pytest.mark.parametrize("api_type", [AZURE, OPENAI]) # re-enable when supported
43-
@pytest.mark.parametrize("api_type", [OPENAI])
43+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
4444
async def test_image_create_size(self, client_async, azure_openai_creds, api_type, **kwargs):
4545
image = await client_async.images.generate(
4646
prompt="a cute baby seal",
@@ -65,7 +65,7 @@ async def test_image_create_response_format(self, client_async, azure_openai_cre
6565
@configure_async
6666
@pytest.mark.asyncio
6767
# @pytest.mark.parametrize("api_type", [AZURE, OPENAI]) # re-enable when supported
68-
@pytest.mark.parametrize("api_type", [OPENAI])
68+
@pytest.mark.parametrize("api_type", [OPENAI, DALLE_AZURE])
6969
async def test_image_create_user(self, client_async, azure_openai_creds, api_type, **kwargs):
7070
image = await client_async.images.generate(
7171
prompt="a cute baby seal",

0 commit comments

Comments
 (0)