Skip to content

Commit bc4d415

Browse files
authored
[Azure.AI.project] Implement ConnectionProvider abstraction and add extension methods (Azure#48399)
[Azure.AI.project] Implement ConnectionProvider abstraction and add extension methods
1 parent ac0eeb2 commit bc4d415

File tree

47 files changed

+1251
-212
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1251
-212
lines changed

sdk/ai/Azure.AI.Inference/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Added extension methods to get `ChatCompletionsClient` and `EmbeddingsClient` using [AIProjectClient](https://learn.microsoft.com/dotnet/api/azure.ai.projects.aiprojectclient?view=azure-dotnet-preview).
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/ai/Azure.AI.Inference/api/Azure.AI.Inference.net8.0.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
namespace Azure.AI.Inference
22
{
3+
public static partial class AIInferenceExtensions
4+
{
5+
public static Azure.AI.Inference.ChatCompletionsClient GetChatCompletionsClient(this System.ClientModel.Primitives.ConnectionProvider provider) { throw null; }
6+
public static Azure.AI.Inference.EmbeddingsClient GetEmbeddingsClient(this System.ClientModel.Primitives.ConnectionProvider provider) { throw null; }
7+
}
38
public static partial class AIInferenceModelFactory
49
{
510
public static Azure.AI.Inference.ChatChoice ChatChoice(int index = 0, Azure.AI.Inference.CompletionsFinishReason? finishReason = default(Azure.AI.Inference.CompletionsFinishReason?), Azure.AI.Inference.ChatResponseMessage message = null) { throw null; }

sdk/ai/Azure.AI.Inference/api/Azure.AI.Inference.netstandard2.0.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
namespace Azure.AI.Inference
22
{
3+
public static partial class AIInferenceExtensions
4+
{
5+
public static Azure.AI.Inference.ChatCompletionsClient GetChatCompletionsClient(this System.ClientModel.Primitives.ConnectionProvider provider) { throw null; }
6+
public static Azure.AI.Inference.EmbeddingsClient GetEmbeddingsClient(this System.ClientModel.Primitives.ConnectionProvider provider) { throw null; }
7+
}
38
public static partial class AIInferenceModelFactory
49
{
510
public static Azure.AI.Inference.ChatChoice ChatChoice(int index = 0, Azure.AI.Inference.CompletionsFinishReason? finishReason = default(Azure.AI.Inference.CompletionsFinishReason?), Azure.AI.Inference.ChatResponseMessage message = null) { throw null; }

sdk/ai/Azure.AI.Inference/src/Azure.AI.Inference.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,8 @@
1919
<PackageReference Include="Azure.Core" />
2020
<PackageReference Include="System.Text.Json" />
2121
</ItemGroup>
22+
<ItemGroup>
23+
<ProjectReference Include="..\..\..\core\System.ClientModel\src\System.ClientModel.csproj" />
24+
</ItemGroup>
2225

2326
</Project>
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#nullable enable
5+
6+
using System;
7+
using System.ClientModel.Primitives;
8+
using Azure.Core;
9+
10+
namespace Azure.AI.Inference
11+
{
12+
/// <summary>
13+
/// The Azure AI Inference extensions.
14+
/// </summary>
15+
public static class AIInferenceExtensions
16+
{
17+
/// <summary>
18+
/// Gets the chat completion client.
19+
/// </summary>
20+
/// <param name="provider"></param>
21+
/// <returns></returns>
22+
public static ChatCompletionsClient GetChatCompletionsClient(this ConnectionProvider provider)
23+
{
24+
ChatCompletionsClient chatClient = provider.Subclients.GetClient(() => CreateChatCompletionsClient(provider), null);
25+
return chatClient;
26+
}
27+
28+
private static ChatCompletionsClient CreateChatCompletionsClient(this ConnectionProvider provider)
29+
{
30+
ClientConnection connection = provider.GetConnection(typeof(ChatCompletionsClient).FullName!);
31+
if (!connection.TryGetLocatorAsUri(out Uri? uri) || uri is null)
32+
{
33+
throw new InvalidOperationException("Invalid URI.");
34+
}
35+
return connection.Authentication == ClientAuthenticationMethod.Credential
36+
? new ChatCompletionsClient(uri, connection.Credential as TokenCredential)
37+
: new ChatCompletionsClient(uri, new AzureKeyCredential(connection.ApiKeyCredential!));
38+
}
39+
40+
/// <summary>
41+
/// Gets the embeddings client.
42+
/// </summary>
43+
/// <param name="provider"></param>
44+
/// <returns></returns>
45+
public static EmbeddingsClient GetEmbeddingsClient(this ConnectionProvider provider)
46+
{
47+
EmbeddingsClient embeddingsClient = provider.Subclients.GetClient(() => CreateEmbeddingsClient(provider), null);
48+
return embeddingsClient;
49+
}
50+
51+
private static EmbeddingsClient CreateEmbeddingsClient(this ConnectionProvider provider)
52+
{
53+
ClientConnection connection = provider.GetConnection(typeof(ChatCompletionsClient).FullName!);
54+
if (!connection.TryGetLocatorAsUri(out Uri? uri) || uri is null)
55+
{
56+
throw new InvalidOperationException("Invalid URI.");
57+
}
58+
return connection.Authentication == ClientAuthenticationMethod.Credential
59+
? new EmbeddingsClient(uri, connection.Credential as TokenCredential)
60+
: new EmbeddingsClient(uri, new AzureKeyCredential(connection.ApiKeyCredential!));
61+
}
62+
}
63+
}

sdk/ai/Azure.AI.Projects/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
* Added `ConnectionProvider` abstraction in `AIProjectClient` to enable seamless connectivity with Azure OpenAI, Inference, and Search SDKs.
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/ai/Azure.AI.Projects/api/Azure.AI.Projects.net8.0.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,18 +457,18 @@ public static partial class AIClientModelFactory
457457
public static Azure.AI.Projects.ThreadMessage ThreadMessage(string id = null, System.DateTimeOffset createdAt = default(System.DateTimeOffset), string threadId = null, Azure.AI.Projects.MessageStatus status = default(Azure.AI.Projects.MessageStatus), Azure.AI.Projects.MessageIncompleteDetails incompleteDetails = null, System.DateTimeOffset? completedAt = default(System.DateTimeOffset?), System.DateTimeOffset? incompleteAt = default(System.DateTimeOffset?), Azure.AI.Projects.MessageRole role = default(Azure.AI.Projects.MessageRole), System.Collections.Generic.IEnumerable<Azure.AI.Projects.MessageContent> contentItems = null, string agentId = null, string runId = null, System.Collections.Generic.IEnumerable<Azure.AI.Projects.MessageAttachment> attachments = null, System.Collections.Generic.IDictionary<string, string> metadata = null) { throw null; }
458458
public static Azure.AI.Projects.ThreadRun ThreadRun(string id = null, string threadId = null, string agentId = null, Azure.AI.Projects.RunStatus status = default(Azure.AI.Projects.RunStatus), Azure.AI.Projects.RequiredAction requiredAction = null, Azure.AI.Projects.RunError lastError = null, string model = null, string instructions = null, System.Collections.Generic.IEnumerable<Azure.AI.Projects.ToolDefinition> tools = null, System.DateTimeOffset createdAt = default(System.DateTimeOffset), System.DateTimeOffset? expiresAt = default(System.DateTimeOffset?), System.DateTimeOffset? startedAt = default(System.DateTimeOffset?), System.DateTimeOffset? completedAt = default(System.DateTimeOffset?), System.DateTimeOffset? cancelledAt = default(System.DateTimeOffset?), System.DateTimeOffset? failedAt = default(System.DateTimeOffset?), Azure.AI.Projects.IncompleteRunDetails incompleteDetails = null, Azure.AI.Projects.RunCompletionUsage usage = null, float? temperature = default(float?), float? topP = default(float?), int? maxPromptTokens = default(int?), int? maxCompletionTokens = default(int?), Azure.AI.Projects.TruncationObject truncationStrategy = null, System.BinaryData toolChoice = null, System.BinaryData responseFormat = null, System.Collections.Generic.IReadOnlyDictionary<string, string> metadata = null, Azure.AI.Projects.UpdateToolResourcesOptions toolResources = null, bool? parallelToolCalls = default(bool?)) { throw null; }
459459
}
460-
public partial class AIProjectClient
460+
public partial class AIProjectClient : System.ClientModel.Primitives.ConnectionProvider
461461
{
462462
protected AIProjectClient() { }
463-
public AIProjectClient(string connectionString, Azure.Core.TokenCredential credential) { }
463+
public AIProjectClient(string connectionString, Azure.Core.TokenCredential credential = null) { }
464464
public AIProjectClient(string connectionString, Azure.Core.TokenCredential credential, Azure.AI.Projects.AIProjectClientOptions options) { }
465465
public AIProjectClient(System.Uri endpoint, string subscriptionId, string resourceGroupName, string projectName, Azure.Core.TokenCredential credential) { }
466466
public AIProjectClient(System.Uri endpoint, string subscriptionId, string resourceGroupName, string projectName, Azure.Core.TokenCredential credential, Azure.AI.Projects.AIProjectClientOptions options) { }
467467
public virtual Azure.Core.Pipeline.HttpPipeline Pipeline { get { throw null; } }
468468
public virtual Azure.AI.Projects.AgentsClient GetAgentsClient(string apiVersion = "2024-07-01-preview") { throw null; }
469-
public virtual Azure.AI.Inference.ChatCompletionsClient GetChatCompletionsClient() { throw null; }
469+
public override System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections() { throw null; }
470+
public override System.ClientModel.Primitives.ClientConnection GetConnection(string connectionId) { throw null; }
470471
public virtual Azure.AI.Projects.ConnectionsClient GetConnectionsClient(string apiVersion = "2024-07-01-preview") { throw null; }
471-
public virtual Azure.AI.Inference.EmbeddingsClient GetEmbeddingsClient() { throw null; }
472472
public virtual Azure.AI.Projects.EvaluationsClient GetEvaluationsClient(string apiVersion = "2024-07-01-preview") { throw null; }
473473
public virtual Azure.AI.Projects.TelemetryClient GetTelemetryClient(string apiVersion = "2024-07-01-preview") { throw null; }
474474
}

sdk/ai/Azure.AI.Projects/api/Azure.AI.Projects.netstandard2.0.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,18 +457,18 @@ public static partial class AIClientModelFactory
457457
public static Azure.AI.Projects.ThreadMessage ThreadMessage(string id = null, System.DateTimeOffset createdAt = default(System.DateTimeOffset), string threadId = null, Azure.AI.Projects.MessageStatus status = default(Azure.AI.Projects.MessageStatus), Azure.AI.Projects.MessageIncompleteDetails incompleteDetails = null, System.DateTimeOffset? completedAt = default(System.DateTimeOffset?), System.DateTimeOffset? incompleteAt = default(System.DateTimeOffset?), Azure.AI.Projects.MessageRole role = default(Azure.AI.Projects.MessageRole), System.Collections.Generic.IEnumerable<Azure.AI.Projects.MessageContent> contentItems = null, string agentId = null, string runId = null, System.Collections.Generic.IEnumerable<Azure.AI.Projects.MessageAttachment> attachments = null, System.Collections.Generic.IDictionary<string, string> metadata = null) { throw null; }
458458
public static Azure.AI.Projects.ThreadRun ThreadRun(string id = null, string threadId = null, string agentId = null, Azure.AI.Projects.RunStatus status = default(Azure.AI.Projects.RunStatus), Azure.AI.Projects.RequiredAction requiredAction = null, Azure.AI.Projects.RunError lastError = null, string model = null, string instructions = null, System.Collections.Generic.IEnumerable<Azure.AI.Projects.ToolDefinition> tools = null, System.DateTimeOffset createdAt = default(System.DateTimeOffset), System.DateTimeOffset? expiresAt = default(System.DateTimeOffset?), System.DateTimeOffset? startedAt = default(System.DateTimeOffset?), System.DateTimeOffset? completedAt = default(System.DateTimeOffset?), System.DateTimeOffset? cancelledAt = default(System.DateTimeOffset?), System.DateTimeOffset? failedAt = default(System.DateTimeOffset?), Azure.AI.Projects.IncompleteRunDetails incompleteDetails = null, Azure.AI.Projects.RunCompletionUsage usage = null, float? temperature = default(float?), float? topP = default(float?), int? maxPromptTokens = default(int?), int? maxCompletionTokens = default(int?), Azure.AI.Projects.TruncationObject truncationStrategy = null, System.BinaryData toolChoice = null, System.BinaryData responseFormat = null, System.Collections.Generic.IReadOnlyDictionary<string, string> metadata = null, Azure.AI.Projects.UpdateToolResourcesOptions toolResources = null, bool? parallelToolCalls = default(bool?)) { throw null; }
459459
}
460-
public partial class AIProjectClient
460+
public partial class AIProjectClient : System.ClientModel.Primitives.ConnectionProvider
461461
{
462462
protected AIProjectClient() { }
463-
public AIProjectClient(string connectionString, Azure.Core.TokenCredential credential) { }
463+
public AIProjectClient(string connectionString, Azure.Core.TokenCredential credential = null) { }
464464
public AIProjectClient(string connectionString, Azure.Core.TokenCredential credential, Azure.AI.Projects.AIProjectClientOptions options) { }
465465
public AIProjectClient(System.Uri endpoint, string subscriptionId, string resourceGroupName, string projectName, Azure.Core.TokenCredential credential) { }
466466
public AIProjectClient(System.Uri endpoint, string subscriptionId, string resourceGroupName, string projectName, Azure.Core.TokenCredential credential, Azure.AI.Projects.AIProjectClientOptions options) { }
467467
public virtual Azure.Core.Pipeline.HttpPipeline Pipeline { get { throw null; } }
468468
public virtual Azure.AI.Projects.AgentsClient GetAgentsClient(string apiVersion = "2024-07-01-preview") { throw null; }
469-
public virtual Azure.AI.Inference.ChatCompletionsClient GetChatCompletionsClient() { throw null; }
469+
public override System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections() { throw null; }
470+
public override System.ClientModel.Primitives.ClientConnection GetConnection(string connectionId) { throw null; }
470471
public virtual Azure.AI.Projects.ConnectionsClient GetConnectionsClient(string apiVersion = "2024-07-01-preview") { throw null; }
471-
public virtual Azure.AI.Inference.EmbeddingsClient GetEmbeddingsClient() { throw null; }
472472
public virtual Azure.AI.Projects.EvaluationsClient GetEvaluationsClient(string apiVersion = "2024-07-01-preview") { throw null; }
473473
public virtual Azure.AI.Projects.TelemetryClient GetTelemetryClient(string apiVersion = "2024-07-01-preview") { throw null; }
474474
}

sdk/ai/Azure.AI.Projects/src/Azure.AI.Projects.csproj

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
<ItemGroup>
1616
<PackageReference Include="Azure.Core" />
17+
<PackageReference Include="Azure.Identity" />
1718
<PackageReference Include="System.Text.Json" />
18-
<PackageReference Include="Azure.AI.Inference" />
19+
</ItemGroup>
20+
21+
<ItemGroup>
22+
<ProjectReference Include="..\..\..\core\System.ClientModel\src\System.ClientModel.csproj" />
1923
</ItemGroup>
2024

2125
<!-- Shared source from Azure.Core -->

0 commit comments

Comments
 (0)