Skip to content

Commit a24ddc8

Browse files
authored
[SCM] Address connections API feedback to prepare for release. (Azure#49761)
[SCM] Address connections API feedback to prepare for release.
1 parent eba19a8 commit a24ddc8

File tree

11 files changed

+75
-53
lines changed

11 files changed

+75
-53
lines changed

sdk/ai/Azure.AI.Inference/src/Customized/AIInferenceExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ private static EmbeddingsClient CreateEmbeddingsClient(this ConnectionProvider p
7272
};
7373
}
7474

75-
private record ChatCompletionsClientKey() : IEquatable<object>;
75+
private record ChatCompletionsClientKey();
7676

77-
private record EmbeddingsClientKey() : IEquatable<object>;
77+
private record EmbeddingsClientKey();
7878
}
7979
}

sdk/ai/Azure.AI.Projects/src/Custom/AIProjectClient.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ public partial class AIProjectClient : ConnectionProvider
1717
private readonly ConnectionCacheManager _cacheManager;
1818
private readonly ConnectionsClient _connectionsClient;
1919

20+
/// <summary> Initializes a new instance of AIProjectClient for mocking. </summary>
21+
protected AIProjectClient() : base(maxCacheSize: 100)
22+
{
23+
}
24+
2025
/// <summary> Initializes a new instance of AzureAIClient. </summary>
2126
/// <param name="connectionString">The Azure AI Foundry project connection string, in the form `endpoint;subscription_id;resource_group_name;project_name`.</param>
2227
/// <param name="credential"> A credential used to authenticate to an Azure Service. </param>

sdk/ai/Azure.AI.Projects/src/Generated/AIProjectClient.cs

Lines changed: 0 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ protected AsyncCollectionResult() { }
8484
}
8585
public partial class ClientCache
8686
{
87-
public ClientCache(int maxSize = 100) { }
88-
public T GetClient<T>(System.IEquatable<object> clientId, System.Func<T> createClient) where T : class { throw null; }
87+
public ClientCache(int maxSize) { }
88+
public T GetClient<T>(object clientId, System.Func<T> createClient) where T : class { throw null; }
8989
}
9090
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
9191
public readonly partial struct ClientConnection
@@ -174,7 +174,7 @@ public void AddRange(System.Collections.Generic.IEnumerable<System.ClientModel.P
174174
}
175175
public abstract partial class ConnectionProvider
176176
{
177-
protected ConnectionProvider(int maxCacheSize = 100) { }
177+
protected ConnectionProvider(int maxCacheSize) { }
178178
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
179179
public System.ClientModel.Primitives.ClientCache Subclients { get { throw null; } }
180180
public abstract System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections();

sdk/core/System.ClientModel/api/System.ClientModel.net8.0.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ protected AsyncCollectionResult() { }
8484
}
8585
public partial class ClientCache
8686
{
87-
public ClientCache(int maxSize = 100) { }
88-
public T GetClient<T>(System.IEquatable<object> clientId, System.Func<T> createClient) where T : class { throw null; }
87+
public ClientCache(int maxSize) { }
88+
public T GetClient<T>(object clientId, System.Func<T> createClient) where T : class { throw null; }
8989
}
9090
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
9191
public readonly partial struct ClientConnection
@@ -174,7 +174,7 @@ public void AddRange(System.Collections.Generic.IEnumerable<System.ClientModel.P
174174
}
175175
public abstract partial class ConnectionProvider
176176
{
177-
protected ConnectionProvider(int maxCacheSize = 100) { }
177+
protected ConnectionProvider(int maxCacheSize) { }
178178
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
179179
public System.ClientModel.Primitives.ClientCache Subclients { get { throw null; } }
180180
public abstract System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections();

sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ protected AsyncCollectionResult() { }
8484
}
8585
public partial class ClientCache
8686
{
87-
public ClientCache(int maxSize = 100) { }
88-
public T GetClient<T>(System.IEquatable<object> clientId, System.Func<T> createClient) where T : class { throw null; }
87+
public ClientCache(int maxSize) { }
88+
public T GetClient<T>(object clientId, System.Func<T> createClient) where T : class { throw null; }
8989
}
9090
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
9191
public readonly partial struct ClientConnection
@@ -174,7 +174,7 @@ public void AddRange(System.Collections.Generic.IEnumerable<System.ClientModel.P
174174
}
175175
public abstract partial class ConnectionProvider
176176
{
177-
protected ConnectionProvider(int maxCacheSize = 100) { }
177+
protected ConnectionProvider(int maxCacheSize) { }
178178
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
179179
public System.ClientModel.Primitives.ClientCache Subclients { get { throw null; } }
180180
public abstract System.Collections.Generic.IEnumerable<System.ClientModel.Primitives.ClientConnection> GetAllConnections();

sdk/core/System.ClientModel/src/Convenience/ClientCache.cs

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace System.ClientModel.Primitives;
1111
/// </summary>
1212
public class ClientCache
1313
{
14-
private readonly Dictionary<IEquatable<object>, ClientEntry> _clients = new();
14+
private readonly Dictionary<object, ClientEntry> _clients = new();
1515
private readonly ReaderWriterLockSlim _lock = new(LockRecursionPolicy.SupportsRecursion);
1616

1717
private readonly int _maxSize;
@@ -20,7 +20,7 @@ public class ClientCache
2020
/// Initializes the ClientCache with a configurable cache size.
2121
/// </summary>
2222
/// <param name="maxSize">The maximum number of clients to store in the cache.</param>
23-
public ClientCache(int maxSize = 100)
23+
public ClientCache(int maxSize)
2424
{
2525
_maxSize = maxSize;
2626
}
@@ -30,36 +30,58 @@ public ClientCache(int maxSize = 100)
3030
/// Updates the last-used timestamp on every hit.
3131
/// </summary>
3232
/// <typeparam name="T">The type of the client.</typeparam>
33-
/// <param name="clientId">An equality-comparable key representing the client configuration.</param>
33+
/// <param name="clientId">A key representing the client configuration.</param>
3434
/// <param name="createClient">A factory function to create the client if not cached.</param>
3535
/// <returns>The cached or newly created client instance.</returns>
36-
public T GetClient<T>(IEquatable<object> clientId, Func<T> createClient) where T : class
36+
public T GetClient<T>(object clientId, Func<T> createClient) where T : class
3737
{
38-
// If the client exists, update its timestamp.
39-
if (_clients.TryGetValue(clientId, out var cached))
38+
_lock.EnterReadLock();
39+
try
4040
{
41-
cached.LastUsed = Stopwatch.GetTimestamp();
42-
43-
if (cached.Client is T typedClient)
41+
// If the client exists, update its timestamp.
42+
if (_clients.TryGetValue(clientId, out var cached))
4443
{
45-
return typedClient;
46-
}
44+
cached.LastUsed = Stopwatch.GetTimestamp();
4745

48-
throw new InvalidOperationException($"The clientId is associated with a client of type '{cached.Client.GetType()}', not '{typeof(T)}'.");
46+
if (cached.Client is T typedClient)
47+
{
48+
return typedClient;
49+
}
50+
51+
throw new InvalidOperationException($"The clientId is associated with a client of type '{cached.Client.GetType()}', not '{typeof(T)}'.");
52+
}
53+
}
54+
finally
55+
{
56+
_lock.ExitReadLock();
4957
}
5058

51-
// Client not found in cache, create a new one.
59+
// Client not found, enter write lock
5260
_lock.EnterWriteLock();
5361
try
5462
{
63+
// Double-check inside write lock to avoid race condition
64+
if (_clients.TryGetValue(clientId, out var existing))
65+
{
66+
existing.LastUsed = Stopwatch.GetTimestamp();
67+
68+
if (existing.Client is T typedClient)
69+
{
70+
return typedClient;
71+
}
72+
73+
throw new InvalidOperationException($"The clientId is associated with a client of type '{existing.Client.GetType()}', not '{typeof(T)}'.");
74+
}
75+
76+
// Client not found in cache, create a new one.
5577
T created = createClient();
5678
_clients[clientId] = new ClientEntry(created, Stopwatch.GetTimestamp());
5779

58-
// After insertion, if cache exceeds the limit, perform cleanup.
5980
if (_clients.Count > _maxSize)
6081
{
6182
Cleanup();
6283
}
84+
6385
return created;
6486
}
6587
finally

sdk/core/System.ClientModel/src/Convenience/ConnectionProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public abstract class ConnectionProvider
1818
/// Initializes a new instance of the ConnectionProvider class.
1919
/// </summary>
2020
/// <param name="maxCacheSize">The maximum number of subclients to store in the cache.</param>
21-
protected ConnectionProvider(int maxCacheSize = 100)
21+
protected ConnectionProvider(int maxCacheSize)
2222
{
2323
_subclients = new ClientCache(maxCacheSize);
2424
}

sdk/core/System.ClientModel/tests/Convenience/ClientCacheTests.cs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class ClientCacheTests
1313
[Test]
1414
public void CacheShouldCleanupWhenExceedsLimit()
1515
{
16-
var clientCache = new ClientCache();
16+
var clientCache = new ClientCache(100);
1717

1818
// Add 110 clients to trigger the cleanup.
1919
for (int i = 0; i < 110; i++)
@@ -23,7 +23,7 @@ public void CacheShouldCleanupWhenExceedsLimit()
2323
}
2424

2525
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
26-
var clients = clientsField?.GetValue(clientCache) as Dictionary<IEquatable<object>, ClientEntry>;
26+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
2727

2828
Assert.IsNotNull(clients, "The _clients field is null.");
2929
Assert.AreEqual(100, clients!.Count, "Cache did not cleanup correctly.");
@@ -32,7 +32,7 @@ public void CacheShouldCleanupWhenExceedsLimit()
3232
[Test]
3333
public void CacheShouldNotCleanupWhenUnderLimit()
3434
{
35-
var clientCache = new ClientCache();
35+
var clientCache = new ClientCache(100);
3636

3737
// Add 50 clients, which is below the limit.
3838
for (int i = 0; i < 50; i++)
@@ -42,7 +42,7 @@ public void CacheShouldNotCleanupWhenUnderLimit()
4242
}
4343

4444
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
45-
var clients = clientsField?.GetValue(clientCache) as Dictionary<IEquatable<object>, ClientEntry>;
45+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
4646

4747
Assert.IsNotNull(clients, "The _clients field is null.");
4848
Assert.AreEqual(50, clients!.Count, "Cache should not have cleaned up when under the limit.");
@@ -51,7 +51,7 @@ public void CacheShouldNotCleanupWhenUnderLimit()
5151
[Test]
5252
public void CacheShouldCleanupOldestClients()
5353
{
54-
var clientCache = new ClientCache();
54+
var clientCache = new ClientCache(100);
5555

5656
// Add 110 clients to trigger cleanup (exceeds _maxClients = 100)
5757
for (int i = 0; i < 110; i++)
@@ -66,7 +66,7 @@ public void CacheShouldCleanupOldestClients()
6666
clientCache.GetClient(new DummyClientKey("client1"), () => new object());
6767

6868
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
69-
var clients = clientsField?.GetValue(clientCache) as Dictionary<IEquatable<object>, ClientEntry>;
69+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
7070

7171
Assert.IsNotNull(clients, "The _clients field is null.");
7272
Assert.AreEqual(100, clients!.Count, "Cache did not cleanup correctly.");
@@ -88,7 +88,7 @@ public void CacheShouldCleanupOldestClients()
8888
[Test]
8989
public void LRUShouldNotBeRemoved()
9090
{
91-
var clientCache = new ClientCache();
91+
var clientCache = new ClientCache(100);
9292

9393
for (int i = 0; i <= 100; i++)
9494
{
@@ -104,7 +104,7 @@ public void LRUShouldNotBeRemoved()
104104
clientCache.GetClient(new DummyClientKey("client102"), () => new object());
105105

106106
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
107-
var clients = clientsField?.GetValue(clientCache) as Dictionary<IEquatable<object>, ClientEntry>;
107+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
108108

109109
Assert.IsNotNull(clients, "The _clients field is null.");
110110
Assert.AreEqual(100, clients!.Count, "Cache did not cleanup correctly.");
@@ -121,7 +121,7 @@ public void LRUShouldNotBeRemoved()
121121
[Test]
122122
public void CacheShouldDisposeClientsWhenRemoved()
123123
{
124-
var clientCache = new ClientCache();
124+
var clientCache = new ClientCache(100);
125125

126126
// Create a disposable client
127127
var disposableClient = new DisposableClient();
@@ -135,7 +135,7 @@ public void CacheShouldDisposeClientsWhenRemoved()
135135
}
136136

137137
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
138-
var clients = clientsField?.GetValue(clientCache) as Dictionary<IEquatable<object>, ClientEntry>;
138+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
139139

140140
Assert.IsNotNull(clients, "The _clients field is null.");
141141
Assert.IsTrue(disposableClient.IsDisposed, "Disposable client was not disposed correctly.");
@@ -144,7 +144,7 @@ public void CacheShouldDisposeClientsWhenRemoved()
144144
[Test]
145145
public void CacheShouldHandleDifferentClientIdsSeparately()
146146
{
147-
var clientCache = new ClientCache();
147+
var clientCache = new ClientCache(100);
148148

149149
// Add clients with the same type but different IDs
150150
var client1 = new object();
@@ -154,7 +154,7 @@ public void CacheShouldHandleDifferentClientIdsSeparately()
154154
clientCache.GetClient(new DummyClientKey("client2"), () => client2);
155155

156156
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
157-
var clients = clientsField?.GetValue(clientCache) as Dictionary<IEquatable<object>, ClientEntry>;
157+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
158158

159159
Assert.IsNotNull(clients, "The _clients field is null.");
160160
Assert.IsTrue(clients!.ContainsKey(new DummyClientKey("client1")), "Client1 should be in the cache.");
@@ -192,7 +192,7 @@ public void ClientCacheShouldRespectMaxCacheSize()
192192
Assert.False(wasRecreated, "Client A was unexpectedly recreated");
193193

194194
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
195-
var clients = clientsField?.GetValue(clientCache) as IDictionary<IEquatable<object>, ClientEntry>;
195+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
196196

197197
Assert.IsNotNull(clients, "_clients dictionary should not be null");
198198

@@ -209,7 +209,7 @@ public void ClientCacheShouldRespectMaxCacheSize()
209209
[Test]
210210
public void CacheShouldHandleDifferentOptionsSeparately()
211211
{
212-
var clientCache = new ClientCache();
212+
var clientCache = new ClientCache(100);
213213

214214
// Define two different options as DummyClientKeys
215215
var options1 = new ClientPipelineOptions() { EnableDistributedTracing = true };
@@ -226,15 +226,15 @@ public void CacheShouldHandleDifferentOptionsSeparately()
226226
Assert.AreNotSame(client1, client2, "Clients should be distinct when options are different.");
227227

228228
var clientsField = typeof(ClientCache).GetField("_clients", BindingFlags.NonPublic | BindingFlags.Instance);
229-
var clients = clientsField?.GetValue(clientCache) as IDictionary<IEquatable<object>, ClientEntry>;
229+
var clients = clientsField?.GetValue(clientCache) as Dictionary<object, ClientEntry>;
230230

231231
// Assert that both clients are in the cache with the expected keys
232232
Assert.IsTrue(clients!.ContainsKey(new DummyClientKey("abc", options1)), "Client with options1 should be in the cache.");
233233
Assert.IsTrue(clients!.ContainsKey(new DummyClientKey("abc", options2)), "Client with options2 should be in the cache.");
234234
}
235235
}
236236

237-
internal record DummyClientKey(string Key, ClientPipelineOptions? options = null) : IEquatable<object>;
237+
internal record DummyClientKey(string Key, ClientPipelineOptions? options = null);
238238

239239
// Helper class to simulate a disposable client
240240
internal class DisposableClient : IDisposable

sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ private static EmbeddingClient CreateEmbeddingClient(this ConnectionProvider pro
8585
return embedding;
8686
}
8787

88-
private record AzureOpenAIClientKey() : IEquatable<object>;
88+
private record AzureOpenAIClientKey();
8989

90-
private record ChatClientKey(string? DeploymentName) : IEquatable<object>;
90+
private record ChatClientKey(string? DeploymentName);
9191

92-
private record EmbeddingClientKey(string? DeploymentName) : IEquatable<object>;
92+
private record EmbeddingClientKey(string? DeploymentName);
9393
}

0 commit comments

Comments
 (0)