|
4 | 4 | # Licensed under the MIT License. See License.txt in the project root for
|
5 | 5 | # license information.
|
6 | 6 | # --------------------------------------------------------------------------
|
| 7 | +import time |
| 8 | +import asyncio |
| 9 | +import json |
| 10 | +import httpx |
7 | 11 | import os
|
8 | 12 | import pytest
|
9 | 13 | import importlib
|
|
31 | 35 | WHISPER_AZURE = "whisper_azure"
|
32 | 36 | WHISPER_AZURE_AD = "whisper_azuread"
|
33 | 37 | WHISPER_ALL = ["whisper_azure", "whisper_azuread", "openai"]
|
| 38 | +DALLE_AZURE = "dalle_azure" |
34 | 39 |
|
35 | 40 | # Environment variable keys
|
36 | 41 | ENV_AZURE_OPENAI_ENDPOINT = "AZ_OPENAI_ENDPOINT"
|
@@ -101,6 +106,121 @@ def azure_openai_creds():
|
101 | 106 |
|
102 | 107 | # openai>=1.0.0 ---------------------------------------------------------------------------
|
103 | 108 |
|
| 109 | +class CustomHTTPTransport(httpx.HTTPTransport): |
| 110 | + """Temp stop-gap support for DALL-E""" |
| 111 | + def handle_request( |
| 112 | + self, |
| 113 | + request: httpx.Request, |
| 114 | + ) -> httpx.Response: |
| 115 | + if "images/generations" in request.url.path and request.url.params[ |
| 116 | + "api-version" |
| 117 | + ] in [ |
| 118 | + "2023-06-01-preview", |
| 119 | + "2023-07-01-preview", |
| 120 | + "2023-08-01-preview", |
| 121 | + "2023-09-01-preview", |
| 122 | + "2023-10-01-preview", |
| 123 | + ]: |
| 124 | + request.url = request.url.copy_with(path="/openai/images/generations:submit") |
| 125 | + response = super().handle_request(request) |
| 126 | + operation_location_url = response.headers["operation-location"] |
| 127 | + request.url = httpx.URL(operation_location_url) |
| 128 | + request.method = "GET" |
| 129 | + response = super().handle_request(request) |
| 130 | + response.read() |
| 131 | + |
| 132 | + timeout_secs: int = 120 |
| 133 | + start_time = time.time() |
| 134 | + while response.json()["status"] not in ["succeeded", "failed"]: |
| 135 | + if time.time() - start_time > timeout_secs: |
| 136 | + timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} |
| 137 | + return httpx.Response( |
| 138 | + status_code=400, |
| 139 | + headers=response.headers, |
| 140 | + content=json.dumps(timeout).encode("utf-8"), |
| 141 | + request=request, |
| 142 | + ) |
| 143 | + |
| 144 | + time.sleep(int(response.headers.get("retry-after")) or 10) |
| 145 | + response = super().handle_request(request) |
| 146 | + response.read() |
| 147 | + |
| 148 | + if response.json()["status"] == "failed": |
| 149 | + error_data = response.json() |
| 150 | + return httpx.Response( |
| 151 | + status_code=400, |
| 152 | + headers=response.headers, |
| 153 | + content=json.dumps(error_data).encode("utf-8"), |
| 154 | + request=request, |
| 155 | + ) |
| 156 | + |
| 157 | + result = response.json()["result"] |
| 158 | + return httpx.Response( |
| 159 | + status_code=200, |
| 160 | + headers=response.headers, |
| 161 | + content=json.dumps(result).encode("utf-8"), |
| 162 | + request=request, |
| 163 | + ) |
| 164 | + return super().handle_request(request) |
| 165 | + |
| 166 | + |
| 167 | +class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): |
| 168 | + """Temp stop-gap support for DALL-E""" |
| 169 | + async def handle_async_request( |
| 170 | + self, |
| 171 | + request: httpx.Request, |
| 172 | + ) -> httpx.Response: |
| 173 | + if "images/generations" in request.url.path and request.url.params[ |
| 174 | + "api-version" |
| 175 | + ] in [ |
| 176 | + "2023-06-01-preview", |
| 177 | + "2023-07-01-preview", |
| 178 | + "2023-08-01-preview", |
| 179 | + "2023-09-01-preview", |
| 180 | + "2023-10-01-preview", |
| 181 | + ]: |
| 182 | + request.url = request.url.copy_with(path="/openai/images/generations:submit") |
| 183 | + response = await super().handle_async_request(request) |
| 184 | + operation_location_url = response.headers["operation-location"] |
| 185 | + request.url = httpx.URL(operation_location_url) |
| 186 | + request.method = "GET" |
| 187 | + response = await super().handle_async_request(request) |
| 188 | + await response.aread() |
| 189 | + |
| 190 | + timeout_secs: int = 120 |
| 191 | + start_time = time.time() |
| 192 | + while response.json()["status"] not in ["succeeded", "failed"]: |
| 193 | + if time.time() - start_time > timeout_secs: |
| 194 | + timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} |
| 195 | + return httpx.Response( |
| 196 | + status_code=400, |
| 197 | + headers=response.headers, |
| 198 | + content=json.dumps(timeout).encode("utf-8"), |
| 199 | + request=request, |
| 200 | + ) |
| 201 | + |
| 202 | + await asyncio.sleep(int(response.headers.get("retry-after")) or 10) |
| 203 | + response = await super().handle_async_request(request) |
| 204 | + await response.aread() |
| 205 | + |
| 206 | + if response.json()["status"] == "failed": |
| 207 | + error_data = response.json() |
| 208 | + return httpx.Response( |
| 209 | + status_code=400, |
| 210 | + headers=response.headers, |
| 211 | + content=json.dumps(error_data).encode("utf-8"), |
| 212 | + request=request, |
| 213 | + ) |
| 214 | + |
| 215 | + result = response.json()["result"] |
| 216 | + return httpx.Response( |
| 217 | + status_code=200, |
| 218 | + headers=response.headers, |
| 219 | + content=json.dumps(result).encode("utf-8"), |
| 220 | + request=request, |
| 221 | + ) |
| 222 | + return await super().handle_async_request(request) |
| 223 | + |
104 | 224 | @pytest.fixture
|
105 | 225 | def client(api_type):
|
106 | 226 | if os.getenv(ENV_OPENAI_TEST_MODE, "v1") != "v1":
|
@@ -133,6 +253,13 @@ def client(api_type):
|
133 | 253 | azure_ad_token_provider=get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"),
|
134 | 254 | api_version=ENV_AZURE_OPENAI_API_VERSION,
|
135 | 255 | )
|
| 256 | + elif api_type == "dalle_azure": |
| 257 | + client = openai.AzureOpenAI( |
| 258 | + azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT), |
| 259 | + api_key=os.getenv(ENV_AZURE_OPENAI_KEY), |
| 260 | + api_version=ENV_AZURE_OPENAI_API_VERSION, |
| 261 | + http_client=httpx.Client(transport=CustomHTTPTransport()) |
| 262 | + ) |
136 | 263 |
|
137 | 264 | return client
|
138 | 265 |
|
@@ -169,6 +296,13 @@ def client_async(api_type):
|
169 | 296 | azure_ad_token_provider=get_bearer_token_provider_async(AsyncDefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"),
|
170 | 297 | api_version=ENV_AZURE_OPENAI_API_VERSION,
|
171 | 298 | )
|
| 299 | + elif api_type == "dalle_azure": |
| 300 | + client = openai.AsyncAzureOpenAI( |
| 301 | + azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT), |
| 302 | + api_key=os.getenv(ENV_AZURE_OPENAI_KEY), |
| 303 | + api_version=ENV_AZURE_OPENAI_API_VERSION, |
| 304 | + http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport()) |
| 305 | + ) |
172 | 306 |
|
173 | 307 | return client
|
174 | 308 |
|
|
0 commit comments