Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,21 @@ def _build_request(
) -> httpx.Request:
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and not "/deployments" in str(self.base_url):
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"

return super()._build_request(options)

@override
def _prepare_url(self, url: str) -> httpx.URL:
if "/deployments" in str(self.base_url.path) and url not in _deployments_endpoints:
merge_url = httpx.URL(url)
if merge_url.is_relative_url:
merge_path = f"{self.base_url.path.rsplit('/deployments', maxsplit=1)[0]}/{merge_url.path.lstrip('/')}"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than adding it and then removing it, did we consider adding it only when needed?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that would be the better way of dealing with it. However, we'd have to change the current value of client.base_url to take that approach which could break users.

Copy link
Collaborator

@johanste johanste Jun 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we split on/find '/deployments/'? I guess it is unlikely that one would name a deployment "deployments" (and we do a maxsplit=1 so it is only a problem if we have a deployments "to the right" of the intended split point).

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, the base_url property enforces a trailing slash so it will always be there, regardless of whether the user added it.

return self.base_url.copy_with(path=merge_path)

return super()._prepare_url(url)


class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
@overload
Expand Down
192 changes: 177 additions & 15 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Union
from typing_extensions import Literal

Expand All @@ -22,21 +24,6 @@
)


@pytest.mark.parametrize("client", [sync_client, async_client])
def test_implicit_deployment_path(client: Client) -> None:
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/chat/completions",
json_data={"model": "my-deployment-model"},
)
)
assert (
req.url
== "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01"
)


@pytest.mark.parametrize(
"client,method",
[
Expand Down Expand Up @@ -64,3 +51,178 @@ def test_client_copying_override_options(client: Client) -> None:
api_version="2022-05-01",
)
assert copied._custom_query == {"api-version": "2022-05-01"}


@pytest.mark.parametrize(
"client,base_url,api_path,json_data,expected",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can come up with a couple of additional tests here - e.g. have a dns name with deployments in it, add a deployment called deployments, add a couple of deployments to the azure_endpoint url etc.

[
(
AzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
),
"https://example-resource.azure.openai.com/openai/",
"/chat/completions",
{"model": "my-deployment"},
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
),
"https://example-resource.azure.openai.com/openai/",
"/models",
{},
"https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01"
),
(
AzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
azure_deployment="my-deployment"
),
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/",
"/chat/completions",
{"model": "placeholder"},
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
azure_deployment="my-deployment"
),
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/",
"/models",
{},
"https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01"
),
(
AzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
base_url="https://example.azure-api.net/PTU/deployments/my-deployment/",
),
"https://example.azure-api.net/PTU/deployments/my-deployment/",
"/chat/completions",
{"model": "placeholder"},
"https://example.azure-api.net/PTU/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
base_url="https://example.azure-api.net/PTU/",
),
"https://example.azure-api.net/PTU/",
"/chat/completions",
{"model": "my-deployment"},
"https://example.azure-api.net/PTU/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
base_url="https://example.azure-api.net/PTU/deployments/my-deployment/",
),
"https://example.azure-api.net/PTU/deployments/my-deployment/",
"/models",
{},
"https://example.azure-api.net/PTU/models?api-version=2024-02-01"
),
(
AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
),
"https://example-resource.azure.openai.com/openai/",
"/chat/completions",
{"model": "my-deployment"},
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
),
"https://example-resource.azure.openai.com/openai/",
"/models",
{},
"https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01"
),
(
AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
azure_deployment="my-deployment"
),
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/",
"/chat/completions",
{"model": "placeholder"},
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
azure_endpoint="https://example-resource.azure.openai.com",
azure_deployment="my-deployment"
),
"https://example-resource.azure.openai.com/openai/deployments/my-deployment/",
"/models",
{},
"https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01"
),
(
AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
base_url="https://example.azure-api.net/PTU/deployments/my-deployment/",
),
"https://example.azure-api.net/PTU/deployments/my-deployment/",
"/chat/completions",
{"model": "placeholder"},
"https://example.azure-api.net/PTU/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
base_url="https://example.azure-api.net/PTU/",
),
"https://example.azure-api.net/PTU/",
"/chat/completions",
{"model": "my-deployment"},
"https://example.azure-api.net/PTU/deployments/my-deployment/chat/completions?api-version=2024-02-01"
),
(
AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="example API key",
base_url="https://example.azure-api.net/PTU/deployments/my-deployment/",
),
"https://example.azure-api.net/PTU/deployments/my-deployment/",
"/models",
{},
"https://example.azure-api.net/PTU/models?api-version=2024-02-01"
),
],
)
def test_client_prepare_url(client: Client, base_url: str, api_path: str, json_data: dict[str, str], expected: str) -> None:
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url=api_path,
json_data=json_data,
)
)
assert req.url == expected
assert client.base_url == base_url