Skip to content

Commit eac4d87

Browse files
authored
[System.ClientModel] Support all properties needed to create a client in ClientCache and consolidate ClientConnection.Credential usage (Azure#49315)
[System.ClientModel] Add support for client options in ClientCache
1 parent 8054540 commit eac4d87

File tree

19 files changed

+269
-189
lines changed

19 files changed

+269
-189
lines changed

sdk/ai/Azure.AI.Inference/src/Customized/AIInferenceExtensions.cs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,26 @@ public static class AIInferenceExtensions
2121
/// <returns></returns>
2222
public static ChatCompletionsClient GetChatCompletionsClient(this ConnectionProvider provider)
2323
{
24-
ChatCompletionsClient chatClient = provider.Subclients.GetClient(() => CreateChatCompletionsClient(provider), null);
24+
ChatCompletionsClientKey chatCompletionsClientKey = new();
25+
ChatCompletionsClient chatClient = provider.Subclients.GetClient(chatCompletionsClientKey, () => CreateChatCompletionsClient(provider));
2526
return chatClient;
2627
}
2728

2829
private static ChatCompletionsClient CreateChatCompletionsClient(this ConnectionProvider provider)
2930
{
3031
ClientConnection connection = provider.GetConnection(typeof(ChatCompletionsClient).FullName!);
32+
3133
if (!connection.TryGetLocatorAsUri(out Uri? uri) || uri is null)
3234
{
3335
throw new InvalidOperationException("Invalid URI.");
3436
}
35-
return connection.Authentication == ClientAuthenticationMethod.Credential
36-
? new ChatCompletionsClient(uri, connection.Credential as TokenCredential)
37-
: new ChatCompletionsClient(uri, new AzureKeyCredential(connection.ApiKeyCredential!));
37+
38+
return connection.CredentialKind switch
39+
{
40+
CredentialKind.ApiKeyString => new ChatCompletionsClient(uri, new AzureKeyCredential((string)connection.Credential!)),
41+
CredentialKind.TokenCredential => new ChatCompletionsClient(uri, (TokenCredential)connection.Credential!),
42+
_ => throw new InvalidOperationException($"Unsupported credential kind: {connection.CredentialKind}")
43+
};
3844
}
3945

4046
/// <summary>
@@ -44,20 +50,30 @@ private static ChatCompletionsClient CreateChatCompletionsClient(this Connection
4450
/// <returns></returns>
4551
public static EmbeddingsClient GetEmbeddingsClient(this ConnectionProvider provider)
4652
{
47-
EmbeddingsClient embeddingsClient = provider.Subclients.GetClient(() => CreateEmbeddingsClient(provider), null);
53+
EmbeddingsClientKey embeddingsClientKey = new();
54+
EmbeddingsClient embeddingsClient = provider.Subclients.GetClient(embeddingsClientKey, () => CreateEmbeddingsClient(provider));
4855
return embeddingsClient;
4956
}
5057

5158
private static EmbeddingsClient CreateEmbeddingsClient(this ConnectionProvider provider)
5259
{
5360
ClientConnection connection = provider.GetConnection(typeof(ChatCompletionsClient).FullName!);
61+
5462
if (!connection.TryGetLocatorAsUri(out Uri? uri) || uri is null)
5563
{
5664
throw new InvalidOperationException("Invalid URI.");
5765
}
58-
return connection.Authentication == ClientAuthenticationMethod.Credential
59-
? new EmbeddingsClient(uri, connection.Credential as TokenCredential)
60-
: new EmbeddingsClient(uri, new AzureKeyCredential(connection.ApiKeyCredential!));
66+
67+
return connection.CredentialKind switch
68+
{
69+
CredentialKind.ApiKeyString => new EmbeddingsClient(uri, new AzureKeyCredential((string)connection.Credential!)),
70+
CredentialKind.TokenCredential => new EmbeddingsClient(uri, (TokenCredential)connection.Credential!),
71+
_ => throw new InvalidOperationException($"Unsupported credential kind: {connection.CredentialKind}")
72+
};
6173
}
74+
75+
private record ChatCompletionsClientKey() : IEquatable<object>;
76+
77+
private record EmbeddingsClientKey() : IEquatable<object>;
6278
}
6379
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#if !NET5_0_OR_GREATER
5+
6+
using System.ComponentModel;
7+
namespace System.Runtime.CompilerServices;
8+
9+
[EditorBrowsable(EditorBrowsableState.Never)]
10+
internal static class IsExternalInit { }
11+
12+
#endif // !NET5_0_OR_GREATER

sdk/ai/Azure.AI.Projects/src/Azure.AI.Projects.csproj

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
<ItemGroup>
1616
<PackageReference Include="Azure.Core" />
1717
<PackageReference Include="Azure.Identity" />
18-
<PackageReference Include="System.ClientModel" VersionOverride="1.4.0-beta.3" />
18+
</ItemGroup>
19+
20+
<ItemGroup>
21+
<ProjectReference Include="..\..\..\core\System.ClientModel\src\System.ClientModel.csproj" />
1922
</ItemGroup>
2023

2124
<!-- Shared source from Azure.Core -->

sdk/ai/Azure.AI.Projects/src/Custom/ConnectionCacheManager.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public ClientConnection GetConnection(string connectionId)
5252
throw new ArgumentException($"The API key is missing or invalid for {connectionId}.");
5353
}
5454

55-
var newConnection = new ClientConnection(connectionId, apiKeyAuthProperties.Target, apiKeyAuthProperties.Credentials.Key);
55+
var newConnection = new ClientConnection(connectionId, apiKeyAuthProperties.Target, apiKeyAuthProperties.Credentials.Key, CredentialKind.ApiKeyString);
5656
return _connections.GetOrAdd(connectionId, newConnection);
5757
}
5858
else if (connection.Properties.AuthType == AuthenticationType.EntraId)
@@ -63,7 +63,7 @@ public ClientConnection GetConnection(string connectionId)
6363
throw new ArgumentException($"The AAD authentication target URI is missing or invalid for {connectionId}.");
6464
}
6565

66-
var newConnection = new ClientConnection(connectionId, aadAuthProperties.Target, _tokenCredential);
66+
var newConnection = new ClientConnection(connectionId, aadAuthProperties.Target, _tokenCredential, CredentialKind.TokenCredential);
6767
return _connections.GetOrAdd(connectionId, newConnection);
6868
}
6969

sdk/ai/Azure.AI.Projects/tests/Azure.AI.Projects.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
<!--TODO: remove VersionOverride https://github.com/Azure/azure-sdk-for-net/issues/49730 -->
3636
<PackageReference Include="Azure.AI.OpenAI" VersionOverride="2.2.0-beta.4"/>
3737
<PackageReference Include="Azure.Search.Documents" VersionOverride="11.7.0-beta.3" />
38+
<PackageReference Include="System.Text.Json" />
3839
</ItemGroup>
3940

4041
</Project>

sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,20 @@ protected AsyncCollectionResult() { }
8282
public abstract System.ClientModel.ContinuationToken? GetContinuationToken(System.ClientModel.ClientResult page);
8383
public abstract System.Collections.Generic.IAsyncEnumerable<System.ClientModel.ClientResult> GetRawPagesAsync();
8484
}
85-
public enum ClientAuthenticationMethod
86-
{
87-
Credential = 0,
88-
ApiKey = 1,
89-
NoAuth = 2,
90-
}
9185
public partial class ClientCache
9286
{
9387
public ClientCache(int maxSize = 100) { }
94-
public T GetClient<T>(System.Func<T> createClient, string? id) where T : class { throw null; }
88+
public T GetClient<T>(System.IEquatable<object> clientId, System.Func<T> createClient) where T : class { throw null; }
9589
}
9690
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
9791
public readonly partial struct ClientConnection
9892
{
9993
private readonly object _dummy;
10094
private readonly int _dummyPrimitive;
10195
public ClientConnection(string id, string locator) { throw null; }
102-
public ClientConnection(string id, string locator, object credential) { throw null; }
103-
public ClientConnection(string id, string locator, string apiKey) { throw null; }
104-
public string? ApiKeyCredential { get { throw null; } }
105-
public System.ClientModel.Primitives.ClientAuthenticationMethod Authentication { get { throw null; } }
96+
public ClientConnection(string id, string locator, object credential, System.ClientModel.Primitives.CredentialKind credentialKind) { throw null; }
10697
public object? Credential { get { throw null; } }
98+
public System.ClientModel.Primitives.CredentialKind CredentialKind { get { throw null; } }
10799
public string Id { get { throw null; } }
108100
public string Locator { get { throw null; } }
109101
public override string ToString() { throw null; }
@@ -188,6 +180,12 @@ protected ConnectionProvider(int maxCacheSize = 100) { }
188180
public abstract System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections();
189181
public abstract System.ClientModel.Primitives.ClientConnection GetConnection(string connectionId);
190182
}
183+
public enum CredentialKind
184+
{
185+
None = 0,
186+
ApiKeyString = 1,
187+
TokenCredential = 2,
188+
}
191189
public partial class HttpClientPipelineTransport : System.ClientModel.Primitives.PipelineTransport, System.IDisposable
192190
{
193191
public HttpClientPipelineTransport() { }

sdk/core/System.ClientModel/api/System.ClientModel.net8.0.cs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,20 @@ protected AsyncCollectionResult() { }
8282
public abstract System.ClientModel.ContinuationToken? GetContinuationToken(System.ClientModel.ClientResult page);
8383
public abstract System.Collections.Generic.IAsyncEnumerable<System.ClientModel.ClientResult> GetRawPagesAsync();
8484
}
85-
public enum ClientAuthenticationMethod
86-
{
87-
Credential = 0,
88-
ApiKey = 1,
89-
NoAuth = 2,
90-
}
9185
public partial class ClientCache
9286
{
9387
public ClientCache(int maxSize = 100) { }
94-
public T GetClient<T>(System.Func<T> createClient, string? id) where T : class { throw null; }
88+
public T GetClient<T>(System.IEquatable<object> clientId, System.Func<T> createClient) where T : class { throw null; }
9589
}
9690
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
9791
public readonly partial struct ClientConnection
9892
{
9993
private readonly object _dummy;
10094
private readonly int _dummyPrimitive;
10195
public ClientConnection(string id, string locator) { throw null; }
102-
public ClientConnection(string id, string locator, object credential) { throw null; }
103-
public ClientConnection(string id, string locator, string apiKey) { throw null; }
104-
public string? ApiKeyCredential { get { throw null; } }
105-
public System.ClientModel.Primitives.ClientAuthenticationMethod Authentication { get { throw null; } }
96+
public ClientConnection(string id, string locator, object credential, System.ClientModel.Primitives.CredentialKind credentialKind) { throw null; }
10697
public object? Credential { get { throw null; } }
98+
public System.ClientModel.Primitives.CredentialKind CredentialKind { get { throw null; } }
10799
public string Id { get { throw null; } }
108100
public string Locator { get { throw null; } }
109101
public override string ToString() { throw null; }
@@ -188,6 +180,12 @@ protected ConnectionProvider(int maxCacheSize = 100) { }
188180
public abstract System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections();
189181
public abstract System.ClientModel.Primitives.ClientConnection GetConnection(string connectionId);
190182
}
183+
public enum CredentialKind
184+
{
185+
None = 0,
186+
ApiKeyString = 1,
187+
TokenCredential = 2,
188+
}
191189
public partial class HttpClientPipelineTransport : System.ClientModel.Primitives.PipelineTransport, System.IDisposable
192190
{
193191
public HttpClientPipelineTransport() { }

sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,20 @@ protected AsyncCollectionResult() { }
8282
public abstract System.ClientModel.ContinuationToken? GetContinuationToken(System.ClientModel.ClientResult page);
8383
public abstract System.Collections.Generic.IAsyncEnumerable<System.ClientModel.ClientResult> GetRawPagesAsync();
8484
}
85-
public enum ClientAuthenticationMethod
86-
{
87-
Credential = 0,
88-
ApiKey = 1,
89-
NoAuth = 2,
90-
}
9185
public partial class ClientCache
9286
{
9387
public ClientCache(int maxSize = 100) { }
94-
public T GetClient<T>(System.Func<T> createClient, string? id) where T : class { throw null; }
88+
public T GetClient<T>(System.IEquatable<object> clientId, System.Func<T> createClient) where T : class { throw null; }
9589
}
9690
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
9791
public readonly partial struct ClientConnection
9892
{
9993
private readonly object _dummy;
10094
private readonly int _dummyPrimitive;
10195
public ClientConnection(string id, string locator) { throw null; }
102-
public ClientConnection(string id, string locator, object credential) { throw null; }
103-
public ClientConnection(string id, string locator, string apiKey) { throw null; }
104-
public string? ApiKeyCredential { get { throw null; } }
105-
public System.ClientModel.Primitives.ClientAuthenticationMethod Authentication { get { throw null; } }
96+
public ClientConnection(string id, string locator, object credential, System.ClientModel.Primitives.CredentialKind credentialKind) { throw null; }
10697
public object? Credential { get { throw null; } }
98+
public System.ClientModel.Primitives.CredentialKind CredentialKind { get { throw null; } }
10799
public string Id { get { throw null; } }
108100
public string Locator { get { throw null; } }
109101
public override string ToString() { throw null; }
@@ -188,6 +180,12 @@ protected ConnectionProvider(int maxCacheSize = 100) { }
188180
public abstract System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections();
189181
public abstract System.ClientModel.Primitives.ClientConnection GetConnection(string connectionId);
190182
}
183+
public enum CredentialKind
184+
{
185+
None = 0,
186+
ApiKeyString = 1,
187+
TokenCredential = 2,
188+
}
191189
public partial class HttpClientPipelineTransport : System.ClientModel.Primitives.PipelineTransport, System.IDisposable
192190
{
193191
public HttpClientPipelineTransport() { }

sdk/core/System.ClientModel/src/Convenience/ClientCache.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
using System.Collections.Generic;
54
using System.Diagnostics;
6-
using System.Linq;
7-
using System.Threading;
85

96
namespace System.ClientModel.Primitives;
107

@@ -14,7 +11,7 @@ namespace System.ClientModel.Primitives;
1411
/// </summary>
1512
public class ClientCache
1613
{
17-
private readonly Dictionary<(Type, string), ClientEntry> _clients = new();
14+
private readonly Dictionary<IEquatable<object>, ClientEntry> _clients = new();
1815
private readonly ReaderWriterLockSlim _lock = new(LockRecursionPolicy.SupportsRecursion);
1916

2017
private readonly int _maxSize;
@@ -33,26 +30,30 @@ public ClientCache(int maxSize = 100)
3330
/// Updates the last-used timestamp on every hit.
3431
/// </summary>
3532
/// <typeparam name="T">The type of the client.</typeparam>
33+
/// <param name="clientId">An equality-comparable key representing the client configuration.</param>
3634
/// <param name="createClient">A factory function to create the client if not cached.</param>
37-
/// <param name="id">An identifier for the client instance.</param>
3835
/// <returns>The cached or newly created client instance.</returns>
39-
public T GetClient<T>(Func<T> createClient, string? id) where T : class
36+
public T GetClient<T>(IEquatable<object> clientId, Func<T> createClient) where T : class
4037
{
41-
(Type, string) key = (typeof(T), id ?? string.Empty);
42-
4338
// If the client exists, update its timestamp.
44-
if (_clients.TryGetValue(key, out var cached))
39+
if (_clients.TryGetValue(clientId, out var cached))
4540
{
4641
cached.LastUsed = Stopwatch.GetTimestamp();
47-
return (T)cached.Client;
42+
43+
if (cached.Client is T typedClient)
44+
{
45+
return typedClient;
46+
}
47+
48+
throw new InvalidOperationException($"The clientId is associated with a client of type '{cached.Client.GetType()}', not '{typeof(T)}'.");
4849
}
4950

5051
// Client not found in cache, create a new one.
5152
_lock.EnterWriteLock();
5253
try
5354
{
5455
T created = createClient();
55-
_clients[key] = new ClientEntry(created, Stopwatch.GetTimestamp());
56+
_clients[clientId] = new ClientEntry(created, Stopwatch.GetTimestamp());
5657

5758
// After insertion, if cache exceeds the limit, perform cleanup.
5859
if (_clients.Count > _maxSize)

0 commit comments

Comments
 (0)