
Commit 444ed8b

Azure AI Inference SDK - Beta 2 updates (#36163)
The main reason for this release, shortly after the first release:

- Add strongly-typed `model` as an optional input argument to the `complete` method of `ChatCompletionsClient`. This is required for a high-visibility project, in which developers must set `model`.

Breaking change (noted in CHANGELOG.md):

- The field `input_tokens` was removed from class `EmbeddingsUsage`, as it was never defined in the REST API and the service never returned this value.

Other changes in this release:

- Address some test debt (work in progress).
- Add tests for setting `model_extras` on the sync and async clients. Verify that the additional parameters appear at the root of the JSON request payload, and that the `unknown_parameters` HTTP request header is set to `pass_through`.
- Add tests to validate serialization of a dummy chat completion request that includes all types of input objects. This is a regression test (no service response needed): it compares the JSON request payload against a hard-coded expected string that was previously verified by hand. The test covers the new `model` argument, as well as all other arguments defined by the REST API, and will catch any regressions in hand-written code.
- Update ref docs to remove mentions of the old `extras` input argument to chat completions in hand-written code. The name was changed to `model_extras` before the first release, but some left-over ref-doc comments still described the no-longer-existing argument.
- Remove an unused function from the sample `sample_chat_completions_with_image_data.py`, which was forgotten in the first release.
- Minor changes to the root README.md.
- Indicate that the `complete` method with `stream=True` returns `Iterable[StreamingChatCompletionsUpdate]` for the synchronous `ChatCompletionsClient`, and `AsyncIterable[StreamingChatCompletionsUpdate]` for the asynchronous `ChatCompletionsClient`. Per feedback from Anna T.
- Update environment variable names used by sample code and tests to start with "AZURE_AI", as is common elsewhere, per feedback from Rob C.
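The new `model` keyword behaves like the other optional top-level request fields. A minimal sketch (toy code, not the SDK's serializer; the function name is hypothetical) of how such optional keywords land at the root of the JSON request payload, with unset values dropped:

```python
import json

def build_chat_request_body(messages, *, model=None, seed=None, max_tokens=None):
    # Mirror the body-construction pattern in the diff: every optional
    # keyword maps to a root-level key of the JSON payload.
    body = {
        "messages": messages,
        "model": model,
        "seed": seed,
        "max_tokens": max_tokens,
    }
    # Drop keys the caller did not set, so they are absent from the wire payload.
    return {k: v for k, v in body.items() if v is not None}

body = build_chat_request_body(
    [{"role": "user", "content": "How many feet are in a mile?"}],
    model="mistral-large",
)
print(json.dumps(body))
```

This is why the serialization regression test can compare the payload to a hard-coded expected string: the body is a plain, deterministic mapping of the keyword arguments.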
1 parent f792549 commit 444ed8b


44 files changed: +594 −479 lines

sdk/ai/azure-ai-inference/CHANGELOG.md

Lines changed: 5 additions & 4 deletions
@@ -1,14 +1,15 @@
 # Release History
 
-## 1.0.0b2 (Unreleased)
+## 1.0.0b2 (2024-06-24)
 
 ### Features Added
 
-### Breaking Changes
+Add `model` as an optional input argument to the `complete` method of `ChatCompletionsClient`.
 
-### Bugs Fixed
+### Breaking Changes
 
-### Other Changes
+The field `input_tokens` was removed from class `EmbeddingsUsage`, as this was never defined in the
+REST API and the service never returned this value.
 
 ## 1.0.0b1 (2024-06-11)
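For callers migrating across the breaking change, here is a hedged sketch using a stand-in `EmbeddingsUsage`. The remaining field names (`prompt_tokens`, `total_tokens`) are assumptions based on the REST API, not copied from the SDK source:

```python
from dataclasses import dataclass

# Stand-in for the SDK model after the breaking change: no `input_tokens`.
@dataclass
class EmbeddingsUsage:
    prompt_tokens: int
    total_tokens: int

usage = EmbeddingsUsage(prompt_tokens=12, total_tokens=12)
# Code that previously read `usage.input_tokens` must be updated;
# the service never returned that value anyway, so nothing is lost.
assert not hasattr(usage, "input_tokens")
print(usage.total_tokens)
```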

sdk/ai/azure-ai-inference/README.md

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ To load an asynchronous client, import the `load_client` function from `azure.ai
 
 Entra ID authentication is also supported by the `load_client` function. Replace the key authentication above with `credential=DefaultAzureCredential()` for example.
 
-### Getting AI model information
+### Get AI model information
 
 All clients provide a `get_model_info` method to retrive AI model information. This makes a REST call to the `/info` route on the provided endpoint, as documented in [the REST API reference](https://learn.microsoft.com/azure/ai-studio/reference/reference-model-inference-info).

sdk/ai/azure-ai-inference/azure/ai/inference/_client.py

Lines changed: 4 additions & 3 deletions
@@ -8,6 +8,7 @@
 
 from copy import deepcopy
 from typing import Any, TYPE_CHECKING, Union
+from typing_extensions import Self
 
 from azure.core import PipelineClient
 from azure.core.credentials import AzureKeyCredential
@@ -101,7 +102,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
     def close(self) -> None:
         self._client.close()
 
-    def __enter__(self) -> "ChatCompletionsClient":
+    def __enter__(self) -> Self:
         self._client.__enter__()
         return self
 
@@ -179,7 +180,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
     def close(self) -> None:
         self._client.close()
 
-    def __enter__(self) -> "EmbeddingsClient":
+    def __enter__(self) -> Self:
         self._client.__enter__()
         return self
 
@@ -257,7 +258,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
     def close(self) -> None:
         self._client.close()
 
-    def __enter__(self) -> "ImageEmbeddingsClient":
+    def __enter__(self) -> Self:
         self._client.__enter__()
         return self
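The switch from hard-coded string annotations to `Self` matters once a client class is subclassed. A toy illustration (not the SDK's code) of the pattern:

```python
# With `-> "BaseClient"`, a type checker would narrow the context-manager
# result to the base class; with `-> Self`, subclasses keep their own type.
try:
    from typing import Self  # Python 3.11+
except ImportError:
    from typing_extensions import Self

class BaseClient:
    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc) -> None:
        pass

class SpecialClient(BaseClient):
    def ping(self) -> str:
        return "pong"

with SpecialClient() as client:
    # `client` is statically known to be a SpecialClient, so this
    # subclass-only method passes type checking.
    result = client.ping()
```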

sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py

Lines changed: 2 additions & 1 deletion
@@ -883,5 +883,6 @@ def rest_discriminator(
     *,
     name: typing.Optional[str] = None,
     type: typing.Optional[typing.Callable] = None,  # pylint: disable=redefined-builtin
+    visibility: typing.Optional[typing.List[str]] = None,
 ) -> typing.Any:
-    return _RestField(name=name, type=type, is_discriminator=True)
+    return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility)

sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py

Lines changed: 11 additions & 3 deletions
@@ -208,6 +208,7 @@ def _complete(
             Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
         ] = None,
         seed: Optional[int] = None,
+        model: Optional[str] = None,
         **kwargs: Any
     ) -> _models.ChatCompletions: ...
     @overload
@@ -240,6 +241,7 @@ def _complete(
             Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
         ] = None,
         seed: Optional[int] = None,
+        model: Optional[str] = None,
         **kwargs: Any
     ) -> _models.ChatCompletions:
         # pylint: disable=line-too-long
@@ -317,9 +319,12 @@ def _complete(
             ~azure.ai.inference.models.ChatCompletionsNamedToolSelection
         :keyword seed: If specified, the system will make a best effort to sample deterministically
             such that repeated requests with the
-            same seed and parameters should return the same result. Determinism is not guaranteed.".
-            Default value is None.
+            same seed and parameters should return the same result. Determinism is not guaranteed. Default
+            value is None.
         :paramtype seed: int
+        :keyword model: ID of the specific AI model to use, if more than one model is available on the
+            endpoint. Default value is None.
+        :paramtype model: str
         :return: ChatCompletions. The ChatCompletions is compatible with MutableMapping
         :rtype: ~azure.ai.inference.models.ChatCompletions
         :raises ~azure.core.exceptions.HttpResponseError:
@@ -338,6 +343,8 @@ def _complete(
                 frequency increases and decrease the likelihood of the model repeating the same
                 statements verbatim. Supported range is [-2, 2].
             "max_tokens": 0, # Optional. The maximum number of tokens to generate.
+            "model": "str", # Optional. ID of the specific AI model to use, if more than
+                one model is available on the endpoint.
             "presence_penalty": 0.0, # Optional. A value that influences the probability
                 of generated tokens appearing based on their existing presence in generated text.
                 Positive values will make tokens less likely to appear when they already exist
@@ -348,7 +355,7 @@ def _complete(
                 "json_object".
             "seed": 0, # Optional. If specified, the system will make a best effort to
                 sample deterministically such that repeated requests with the same seed and
-                parameters should return the same result. Determinism is not guaranteed.".
+                parameters should return the same result. Determinism is not guaranteed.
             "stop": [
                 "str" # Optional. A collection of textual sequences that will end
                     completions generation.
@@ -435,6 +442,7 @@ def _complete(
             "frequency_penalty": frequency_penalty,
             "max_tokens": max_tokens,
             "messages": messages,
+            "model": model,
             "presence_penalty": presence_penalty,
             "response_format": response_format,
             "seed": seed,

sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py

Lines changed: 31 additions & 30 deletions
@@ -24,7 +24,7 @@
 import sys
 
 from io import IOBase
-from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING
+from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, Iterable
 
 from azure.core.pipeline import PipelineResponse
 from azure.core.credentials import AzureKeyCredential
@@ -75,7 +75,7 @@
 
 def load_client(
     endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any
-) -> Union[ChatCompletionsClientGenerated, EmbeddingsClientGenerated, ImageEmbeddingsClientGenerated]:
+) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]:
     """
     Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route
     on the given endpoint, to determine the model type and therefore which client to instantiate.
@@ -90,7 +90,7 @@ def load_client(
        "2024-05-01-preview". Note that overriding this default value may result in unsupported
        behavior.
    :paramtype api_version: str
-   :return: The appropriate client associated with the given endpoint
+   :return: The appropriate synchronous client associated with the given endpoint
    :rtype: ~azure.ai.inference.ChatCompletionsClient or ~azure.ai.inference.EmbeddingsClient
        or ~azure.ai.inference.ImageEmbeddingsClient
    :raises ~azure.core.exceptions.HttpResponseError
@@ -110,7 +110,9 @@ def load_client(
    # TODO: Remove "completions" and "embedding" once Mistral Large and Cohere fixes their model type
    if model_info.model_type in (_models.ModelType.CHAT, "completion"):
        chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs)
-       chat_completion_client._model_info = model_info  # pylint: disable=protected-access,attribute-defined-outside-init
+       chat_completion_client._model_info = (  # pylint: disable=protected-access,attribute-defined-outside-init
+           model_info
+       )
        return chat_completion_client
 
    if model_info.model_type in (_models.ModelType.EMBEDDINGS, "embedding"):
@@ -120,7 +122,9 @@ def load_client(
 
    if model_info.model_type == _models.ModelType.IMAGE_EMBEDDINGS:
        image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs)
-       image_embedding_client._model_info = model_info  # pylint: disable=protected-access,attribute-defined-outside-init
+       image_embedding_client._model_info = (  # pylint: disable=protected-access,attribute-defined-outside-init
+           model_info
+       )
        return image_embedding_client
 
    raise ValueError(f"No client available to support AI model type `{model_info.model_type}`")
@@ -165,6 +169,7 @@ def complete(
             Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
         ] = None,
         seed: Optional[int] = None,
+        model: Optional[str] = None,
         **kwargs: Any,
     ) -> _models.ChatCompletions: ...
 
@@ -188,8 +193,9 @@ def complete(
             Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
         ] = None,
         seed: Optional[int] = None,
+        model: Optional[str] = None,
         **kwargs: Any,
-    ) -> _models.StreamingChatCompletions: ...
+    ) -> Iterable[_models.StreamingChatCompletionsUpdate]: ...
 
     @overload
     def complete(
@@ -211,8 +217,9 @@ def complete(
             Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
         ] = None,
         seed: Optional[int] = None,
+        model: Optional[str] = None,
         **kwargs: Any,
-    ) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
+    ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
         # pylint: disable=line-too-long
         """Gets chat completions for the provided chat messages.
         Completions support a wide variety of tasks and generate text that continues from or
@@ -294,10 +301,13 @@ def complete(
             ~azure.ai.inference.models.ChatCompletionsNamedToolSelection
         :keyword seed: If specified, the system will make a best effort to sample deterministically
             such that repeated requests with the
-            same seed and parameters should return the same result. Determinism is not guaranteed.".
+            same seed and parameters should return the same result. Determinism is not guaranteed.
             Default value is None.
         :paramtype seed: int
-        :return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
+        :keyword model: ID of the specific AI model to use, if more than one model is available on the
+            endpoint. Default value is None.
+        :paramtype model: str
+        :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
         :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
         :raises ~azure.core.exceptions.HttpResponseError
         """
@@ -309,7 +319,7 @@ def complete(
         *,
         content_type: str = "application/json",
         **kwargs: Any,
-    ) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
+    ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
         # pylint: disable=line-too-long
         """Gets chat completions for the provided chat messages.
         Completions support a wide variety of tasks and generate text that continues from or
@@ -321,7 +331,7 @@ def complete(
         :keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
             Default value is "application/json".
         :paramtype content_type: str
-        :return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
+        :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
         :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
         :raises ~azure.core.exceptions.HttpResponseError
         """
@@ -333,7 +343,7 @@ def complete(
         *,
         content_type: str = "application/json",
         **kwargs: Any,
-    ) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
+    ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
         # pylint: disable=line-too-long
         # pylint: disable=too-many-locals
         """Gets chat completions for the provided chat messages.
@@ -345,7 +355,7 @@ def complete(
         :keyword content_type: Body Parameter content-type. Content type parameter for binary body.
             Default value is "application/json".
         :paramtype content_type: str
-        :return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
+        :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
         :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
         :raises ~azure.core.exceptions.HttpResponseError
         """
@@ -370,8 +380,9 @@ def complete(
             Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
         ] = None,
         seed: Optional[int] = None,
+        model: Optional[str] = None,
         **kwargs: Any,
-    ) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
+    ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
         # pylint: disable=line-too-long
         # pylint: disable=too-many-locals
         """Gets chat completions for the provided chat messages.
@@ -451,10 +462,13 @@ def complete(
             ~azure.ai.inference.models.ChatCompletionsNamedToolSelection
         :keyword seed: If specified, the system will make a best effort to sample deterministically
             such that repeated requests with the
-            same seed and parameters should return the same result. Determinism is not guaranteed.".
+            same seed and parameters should return the same result. Determinism is not guaranteed.
             Default value is None.
         :paramtype seed: int
-        :return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
+        :keyword model: ID of the specific AI model to use, if more than one model is available on the
+            endpoint. Default value is None.
+        :paramtype model: str
+        :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
         :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
         :raises ~azure.core.exceptions.HttpResponseError
         """
@@ -479,6 +493,7 @@ def complete(
             "frequency_penalty": frequency_penalty,
             "max_tokens": max_tokens,
             "messages": messages,
+            "model": model,
             "presence_penalty": presence_penalty,
             "response_format": response_format,
             "seed": seed,
@@ -603,13 +618,6 @@ def embed(
         :keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
             Default value is "application/json".
         :paramtype content_type: str
-        :keyword extras: Extra parameters (in the form of string key-value pairs) that are not in the
-            standard request payload.
-            They will be passed to the service as-is in the root of the JSON request payload.
-            How the service handles these extra parameters depends on the value of the
-            ``extra-parameters``
-            HTTP request header. Default value is None.
-        :paramtype extras: dict[str, str]
         :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should
             have.
             Passing null causes the model to use its default value.
@@ -855,13 +863,6 @@ def embed(
         :keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
             Default value is "application/json".
         :paramtype content_type: str
-        :keyword extras: Extra parameters (in the form of string key-value pairs) that are not in the
-            standard request payload.
-            They will be passed to the service as-is in the root of the JSON request payload.
-            How the service handles these extra parameters depends on the value of the
-            ``extra-parameters``
-            HTTP request header. Default value is None.
-        :paramtype extras: dict[str, str]
         :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should
             have.
             Passing null causes the model to use its default value.
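The revised annotations describe a lazily evaluated stream. A self-contained sketch, with a hypothetical stand-in for the SDK's `StreamingChatCompletionsUpdate`, of how a caller consumes such an `Iterable`:

```python
from dataclasses import dataclass
from typing import Iterable

# Toy stand-in for the SDK update type; the real class carries choices,
# deltas, and usage, not a bare `content` string.
@dataclass
class StreamingChatCompletionsUpdate:
    content: str

def fake_complete_stream() -> Iterable[StreamingChatCompletionsUpdate]:
    # A generator satisfies the Iterable annotation: updates arrive one
    # at a time, just as streamed chunks do from the service.
    for piece in ["The", " answer", " is", " 5280."]:
        yield StreamingChatCompletionsUpdate(content=piece)

# Accumulate the streamed pieces into the full completion text.
text = "".join(update.content for update in fake_complete_stream())
print(text)
```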

sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py

Lines changed: 4 additions & 3 deletions
@@ -8,6 +8,7 @@
 
 from copy import deepcopy
 from typing import Any, Awaitable, TYPE_CHECKING, Union
+from typing_extensions import Self
 
 from azure.core import AsyncPipelineClient
 from azure.core.credentials import AzureKeyCredential
@@ -105,7 +106,7 @@ def send_request(
     async def close(self) -> None:
         await self._client.close()
 
-    async def __aenter__(self) -> "ChatCompletionsClient":
+    async def __aenter__(self) -> Self:
         await self._client.__aenter__()
         return self
 
@@ -187,7 +188,7 @@ def send_request(
     async def close(self) -> None:
         await self._client.close()
 
-    async def __aenter__(self) -> "EmbeddingsClient":
+    async def __aenter__(self) -> Self:
         await self._client.__aenter__()
         return self
 
@@ -269,7 +270,7 @@ def send_request(
     async def close(self) -> None:
         await self._client.close()
 
-    async def __aenter__(self) -> "ImageEmbeddingsClient":
+    async def __aenter__(self) -> Self:
         await self._client.__aenter__()
         return self
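The same `Self` pattern applies to `__aenter__` in the asynchronous clients. A toy example (not the SDK's code) showing the async context-manager shape:

```python
import asyncio

try:
    from typing import Self  # Python 3.11+
except ImportError:
    from typing_extensions import Self

class AsyncBaseClient:
    # `-> Self` lets subclasses keep their own type through `async with`.
    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, *exc) -> None:
        pass

class AsyncSpecialClient(AsyncBaseClient):
    async def ping(self) -> str:
        return "pong"

async def main() -> str:
    async with AsyncSpecialClient() as client:
        # `client` is typed as AsyncSpecialClient thanks to `Self`.
        return await client.ping()

result = asyncio.run(main())
print(result)
```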
