Skip to content

Commit a8296c3

Browse files
Havretrayokota
andauthored
Remove shared mutex from CachedSchemaRegistryClient (#2449)
* Remove shared mutex from CachedSchemaRegistryClient * Fix failing test --------- Co-authored-by: Robert Yokota <[email protected]>
1 parent 3099db0 commit a8296c3

File tree

4 files changed

+132
-159
lines changed

4 files changed

+132
-159
lines changed

src/Confluent.SchemaRegistry/CachedSchemaRegistryClient.cs

Lines changed: 52 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
using System.Net.Http;
2727
using System.Collections.Concurrent;
2828
using System.Net;
29-
using System.Threading;
3029
using System.Security.Cryptography.X509Certificates;
3130
using Confluent.Kafka;
31+
using Confluent.Shared.CollectionUtils;
3232
using Microsoft.Extensions.Caching.Memory;
3333

3434

@@ -69,23 +69,21 @@ private record struct SchemaId(int Id, string Format);
6969
private IRestService restService;
7070
private int identityMapCapacity;
7171
private int latestCacheTtlSecs;
72-
private readonly ConcurrentDictionary<SchemaId, Schema> schemaById = new ConcurrentDictionary<SchemaId, Schema>();
72+
private readonly ConcurrentDictionary<SchemaId, Task<Schema>> schemaById = new ConcurrentDictionary<SchemaId, Task<Schema>>();
7373

74-
private readonly ConcurrentDictionary<string /*subject*/, ConcurrentDictionary<Schema, int>> idBySchemaBySubject =
75-
new ConcurrentDictionary<string, ConcurrentDictionary<Schema, int>>();
74+
private readonly ConcurrentDictionary<string /*subject*/, ConcurrentDictionary<Schema, Task<int>>> idBySchemaBySubject =
75+
new ConcurrentDictionary<string, ConcurrentDictionary<Schema, Task<int>>>();
7676

77-
private readonly ConcurrentDictionary<string /*subject*/, ConcurrentDictionary<int, RegisteredSchema>> schemaByVersionBySubject =
78-
new ConcurrentDictionary<string, ConcurrentDictionary<int, RegisteredSchema>>();
77+
private readonly ConcurrentDictionary<string /*subject*/, ConcurrentDictionary<int, Task<RegisteredSchema>>> schemaByVersionBySubject =
78+
new ConcurrentDictionary<string, ConcurrentDictionary<int, Task<RegisteredSchema>>>();
7979

80-
private readonly ConcurrentDictionary<string /*subject*/, ConcurrentDictionary<Schema, RegisteredSchema>> registeredSchemaBySchemaBySubject =
81-
new ConcurrentDictionary<string, ConcurrentDictionary<Schema, RegisteredSchema>>();
80+
private readonly ConcurrentDictionary<string /*subject*/, ConcurrentDictionary<Schema, Task<RegisteredSchema>>> registeredSchemaBySchemaBySubject =
81+
new ConcurrentDictionary<string, ConcurrentDictionary<Schema, Task<RegisteredSchema>>>();
8282

8383
private readonly MemoryCache latestVersionBySubject = new MemoryCache(new MemoryCacheOptions());
8484

8585
private readonly MemoryCache latestWithMetadataBySubject = new MemoryCache(new MemoryCacheOptions());
8686

87-
private readonly SemaphoreSlim cacheMutex = new SemaphoreSlim(1);
88-
8987
private SubjectNameStrategyDelegate keySubjectNameStrategy;
9088
private SubjectNameStrategyDelegate valueSubjectNameStrategy;
9189

@@ -607,45 +605,26 @@ public async Task<int> GetSchemaIdAsync(string subject, Schema schema, bool norm
607605
{
608606
if (idBySchemaBySubject.TryGetValue(subject, out var idBySchema))
609607
{
610-
if (idBySchema.TryGetValue(schema, out int schemaId))
608+
if (idBySchema.TryGetValue(schema, out var schemaId))
611609
{
612-
return schemaId;
610+
return await schemaId;
613611
}
614612
}
615613

616-
await cacheMutex.WaitAsync().ConfigureAwait(continueOnCapturedContext: false);
617-
try
618-
{
619-
if (!this.idBySchemaBySubject.TryGetValue(subject, out idBySchema))
620-
{
621-
idBySchema = new ConcurrentDictionary<Schema, int>();
622-
this.idBySchemaBySubject.TryAdd(subject, idBySchema);
623-
}
624-
625-
// TODO: The following could be optimized in the usual case where idBySchema only
626-
// contains very few elements and the schema string passed in is always the same
627-
// instance.
614+
CleanCacheIfFull();
628615

629-
if (!idBySchema.TryGetValue(schema, out int schemaId))
630-
{
631-
CleanCacheIfFull();
632-
633-
// throws SchemaRegistryException if schema is not known.
634-
var registeredSchema = await restService.LookupSchemaAsync(subject, schema, true, normalize)
635-
.ConfigureAwait(continueOnCapturedContext: false);
636-
idBySchema[schema] = registeredSchema.Id;
637-
638-
var format = GetSchemaFormat(schema.SchemaString);
639-
schemaById.TryAdd(new SchemaId(registeredSchema.Id, format), registeredSchema.Schema);
640-
schemaId = registeredSchema.Id;
641-
}
642-
643-
return schemaId;
644-
}
645-
finally
616+
idBySchema = idBySchemaBySubject.GetOrAdd(subject, _ => new ConcurrentDictionary<Schema, Task<int>>());
617+
return await idBySchema.GetOrAdd(schema, async _ =>
646618
{
647-
cacheMutex.Release();
648-
}
619+
var registeredSchema = await LookupSchemaAsync(subject, schema, true, normalize)
620+
.ConfigureAwait(continueOnCapturedContext: false);
621+
622+
// We already have the schema so we can add it to the cache.
623+
var format = GetSchemaFormat(registeredSchema.SchemaString);
624+
schemaById.TryAdd(new SchemaId(registeredSchema.Id, format), Task.FromResult(registeredSchema.Schema));
625+
626+
return registeredSchema.Id;
627+
}).ConfigureAwait(continueOnCapturedContext: false);
649628
}
650629

651630

@@ -656,41 +635,15 @@ public async Task<int> RegisterSchemaAsync(string subject, Schema schema, bool n
656635
{
657636
if (idBySchema.TryGetValue(schema, out var schemaId))
658637
{
659-
return schemaId;
638+
return await schemaId;
660639
}
661640
}
662641

663-
await cacheMutex.WaitAsync().ConfigureAwait(continueOnCapturedContext: false);
664-
try
665-
{
666-
if (!this.idBySchemaBySubject.TryGetValue(subject, out idBySchema))
667-
{
668-
idBySchema = new ConcurrentDictionary<Schema, int>();
669-
idBySchemaBySubject.TryAdd(subject, idBySchema);
670-
}
671-
672-
// TODO: This could be optimized in the usual case where idBySchema only
673-
// contains very few elements and the schema string passed in is always
674-
// the same instance.
675-
676-
if (!idBySchema.TryGetValue(schema, out int schemaId))
677-
{
678-
CleanCacheIfFull();
679-
680-
schemaId = await restService.RegisterSchemaAsync(subject, schema, normalize)
681-
.ConfigureAwait(continueOnCapturedContext: false);
682-
idBySchema[schema] = schemaId;
683-
}
684-
685-
return schemaId;
686-
}
687-
finally
688-
{
689-
cacheMutex.Release();
690-
}
642+
CleanCacheIfFull();
643+
idBySchema = idBySchemaBySubject.GetOrAdd(subject, _ => new ConcurrentDictionary<Schema, Task<int>>());
644+
return await idBySchema.GetOrAddAsync(schema, _ => restService.RegisterSchemaAsync(subject, schema, normalize)).ConfigureAwait(continueOnCapturedContext: false);
691645
}
692646

693-
694647
/// <inheritdoc/>
695648
public Task<int> RegisterSchemaAsync(string subject, string avroSchema, bool normalize = false)
696649
=> RegisterSchemaAsync(subject, new Schema(avroSchema, EmptyReferencesList, SchemaType.Avro), normalize);
@@ -712,31 +665,14 @@ public async Task<RegisteredSchema> LookupSchemaAsync(string subject, Schema sch
712665
{
713666
if (registeredSchemaBySchema.TryGetValue(schema, out var registeredSchema))
714667
{
715-
return registeredSchema;
668+
return await registeredSchema;
716669
}
717670
}
718671

719-
await cacheMutex.WaitAsync().ConfigureAwait(continueOnCapturedContext: false);
720-
try
721-
{
722-
if (!registeredSchemaBySchemaBySubject.TryGetValue(subject, out registeredSchemaBySchema))
723-
{
724-
CleanCacheIfFull();
725-
registeredSchemaBySchema = new ConcurrentDictionary<Schema, RegisteredSchema>();
726-
registeredSchemaBySchemaBySubject[subject] = registeredSchemaBySchema;
727-
}
728-
if (!registeredSchemaBySchema.TryGetValue(schema, out var registeredSchema))
729-
{
730-
registeredSchema = await restService.LookupSchemaAsync(subject, schema, ignoreDeletedSchemas, normalize).ConfigureAwait(continueOnCapturedContext: false);
731-
registeredSchemaBySchema[schema] = registeredSchema;
732-
}
733-
734-
return registeredSchema;
735-
}
736-
finally
737-
{
738-
cacheMutex.Release();
739-
}
672+
CleanCacheIfFull();
673+
674+
registeredSchemaBySchema = registeredSchemaBySchemaBySubject.GetOrAdd(subject, _ => new ConcurrentDictionary<Schema, Task<RegisteredSchema>>());
675+
return await registeredSchemaBySchema.GetOrAddAsync(schema, _ => restService.LookupSchemaAsync(subject, schema, ignoreDeletedSchemas, normalize)).ConfigureAwait(continueOnCapturedContext: false);
740676
}
741677

742678
/// <inheritdoc/>
@@ -745,25 +681,11 @@ public async Task<Schema> GetSchemaAsync(int id, string format = null)
745681
var schemaId = new SchemaId(id, format);
746682
if (schemaById.TryGetValue(schemaId, out var schema))
747683
{
748-
return schema;
684+
return await schema;
749685
}
750686

751-
await cacheMutex.WaitAsync().ConfigureAwait(continueOnCapturedContext: false);
752-
try
753-
{
754-
if (!this.schemaById.TryGetValue(schemaId, out schema))
755-
{
756-
CleanCacheIfFull();
757-
schema = await restService.GetSchemaAsync(id, format).ConfigureAwait(continueOnCapturedContext: false);
758-
schemaById.TryAdd(schemaId, schema);
759-
}
760-
761-
return schema;
762-
}
763-
finally
764-
{
765-
cacheMutex.Release();
766-
}
687+
CleanCacheIfFull();
688+
return await schemaById.GetOrAddAsync(schemaId, _ => restService.GetSchemaAsync(id, format)).ConfigureAwait(continueOnCapturedContext: false);
767689
}
768690

769691

@@ -773,64 +695,36 @@ public async Task<Schema> GetSchemaBySubjectAndIdAsync(string subject, int id, s
773695
var schemaId = new SchemaId(id, format);
774696
if (this.schemaById.TryGetValue(schemaId, out var schema))
775697
{
776-
return schema;
698+
return await schema;
777699
}
778700

779-
await cacheMutex.WaitAsync().ConfigureAwait(continueOnCapturedContext: false);
780-
try
781-
{
782-
if (!this.schemaById.TryGetValue(schemaId, out schema))
783-
{
784-
CleanCacheIfFull();
785-
schema = await restService.GetSchemaBySubjectAndIdAsync(subject, id, format)
786-
.ConfigureAwait(continueOnCapturedContext: false);
787-
schemaById.TryAdd(schemaId, schema);
788-
}
789-
790-
return schema;
791-
}
792-
finally
793-
{
794-
cacheMutex.Release();
795-
}
701+
return await schemaById.GetOrAddAsync(schemaId, _ => restService.GetSchemaBySubjectAndIdAsync(subject, id, format)).ConfigureAwait(continueOnCapturedContext: false);
796702
}
797703

798704

799705
/// <inheritdoc/>
800706
public async Task<RegisteredSchema> GetRegisteredSchemaAsync(string subject, int version, bool ignoreDeletedSchemas = true)
801707
{
802-
if (schemaByVersionBySubject.TryGetValue(subject, out var schemaByVersion) &&
803-
schemaByVersion.TryGetValue(version, out var schema))
804-
{
805-
return schema;
806-
}
807-
808-
await cacheMutex.WaitAsync().ConfigureAwait(continueOnCapturedContext: false);
809-
try
708+
if (schemaByVersionBySubject.TryGetValue(subject, out var schemaByVersion))
810709
{
811-
CleanCacheIfFull();
812-
813-
if (!schemaByVersionBySubject.TryGetValue(subject, out schemaByVersion))
710+
if (schemaByVersion.TryGetValue(version, out var schema))
814711
{
815-
schemaByVersion = new ConcurrentDictionary<int, RegisteredSchema>();
816-
schemaByVersionBySubject[subject] = schemaByVersion;
712+
return await schema;
817713
}
818-
819-
if (!schemaByVersion.TryGetValue(version, out schema))
820-
{
821-
schema = await restService.GetSchemaAsync(subject, version)
822-
.ConfigureAwait(continueOnCapturedContext: false);
823-
schemaByVersion[version] = schema;
824-
var format = GetSchemaFormat(schema.SchemaString);
825-
schemaById.TryAdd(new SchemaId(schema.Id, format), schema.Schema);
826-
}
827-
828-
return schema;
829-
}
830-
finally
831-
{
832-
cacheMutex.Release();
833714
}
715+
716+
CleanCacheIfFull();
717+
schemaByVersion = schemaByVersionBySubject.GetOrAdd(subject, _ => new ConcurrentDictionary<int, Task<RegisteredSchema>>());
718+
return await schemaByVersion.GetOrAddAsync(version, async _ =>
719+
{
720+
var schema = await restService.GetSchemaAsync(subject, version).ConfigureAwait(continueOnCapturedContext: false);
721+
722+
// We already have the schema so we can add it to the cache.
723+
var format = GetSchemaFormat(schema.SchemaString);
724+
schemaById.TryAdd(new SchemaId(schema.Id, format), Task.FromResult(schema.Schema));
725+
726+
return schema;
727+
}).ConfigureAwait(continueOnCapturedContext: false);
834728
}
835729

836730

src/Confluent.SchemaRegistry/Confluent.SchemaRegistry.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
<ItemGroup>
4747
<Compile Include="..\Shared\SetEqualityComparer.cs" Link="Includes/SetEqualityComparer.cs"/>
4848
<Compile Include="..\Shared\DictionaryEqualityComparer.cs" Link="Includes/DictionaryEqualityComparer.cs"/>
49+
<Compile Include="..\Shared\ConcurrentDictionaryExtensions.cs" Link="Includes/ConcurrentDictionaryExtensions.cs"/>
4950
</ItemGroup>
5051

5152
</Project>

src/Confluent.SchemaRegistry/Rest/DataContracts/RegisteredSchema.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public Schema Schema
6262
{
6363
get
6464
{
65-
return new Schema(SchemaString, References, SchemaType, Metadata, RuleSet);
65+
return new Schema(SchemaString, References ?? new List<SchemaReference>(), SchemaType, Metadata, RuleSet);
6666
}
6767
}
6868

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright 2024 Confluent Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// Refer to LICENSE for more information.
16+
17+
// NOTE: This implementation is based on the original gist by David Fowler:
18+
// https://gist.github.com/davidfowl/3dac8f7b3d141ae87abf770d5781feed
19+
20+
using System;
21+
using System.Collections.Concurrent;
22+
using System.Threading.Tasks;
23+
24+
namespace Confluent.Shared.CollectionUtils;
25+
26+
/// <summary>
27+
/// Extension methods for working with <see cref="ConcurrentDictionary{TKey, TValue}"/>
28+
/// where the values are asynchronous <see cref="Task{TResult}"/>.
29+
/// </summary>
30+
internal static class ConcurrentDictionaryExtensions
31+
{
32+
/// <summary>
33+
/// Asynchronously gets the value associated with the specified key, or adds a new value produced by the specified asynchronous factory function.
34+
///
35+
/// Ensures that the factory function is only invoked once for each key, even when accessed concurrently.
36+
/// If the factory throws, the entry is removed from the dictionary to allow future retries.
37+
/// </summary>
38+
/// <typeparam name="TKey">The type of keys in the dictionary.</typeparam>
39+
/// <typeparam name="TValue">The type of values returned in the <see cref="Task{TValue}"/>.</typeparam>
40+
/// <param name="dictionary">The dictionary to operate on. Values must be of type <see cref="Task{TValue}"/>.</param>
41+
/// <param name="key">The key whose value to get or add.</param>
42+
/// <param name="valueFactory">The asynchronous factory function to generate a value if the key does not exist.</param>
43+
/// <returns>A task representing the asynchronous operation, with the resulting value.</returns>
44+
public static async Task<TValue> GetOrAddAsync<TKey, TValue>(
45+
this ConcurrentDictionary<TKey, Task<TValue>> dictionary,
46+
TKey key,
47+
Func<TKey, Task<TValue>> valueFactory)
48+
{
49+
while (true)
50+
{
51+
if (dictionary.TryGetValue(key, out var task))
52+
{
53+
return await task.ConfigureAwait(continueOnCapturedContext: false);
54+
}
55+
56+
// This is the task that we'll return to all waiters. We'll complete it when the factory is complete
57+
var tcs = new TaskCompletionSource<TValue>(TaskCreationOptions.RunContinuationsAsynchronously);
58+
if (dictionary.TryAdd(key, tcs.Task))
59+
{
60+
try
61+
{
62+
var value = await valueFactory(key).ConfigureAwait(continueOnCapturedContext: false);
63+
tcs.TrySetResult(value);
64+
return await tcs.Task;
65+
}
66+
catch (Exception ex)
67+
{
68+
// Propagate the exception to all awaiting consumers.
69+
tcs.SetException(ex);
70+
71+
// Remove the entry to allow retries on failure.
72+
dictionary.TryRemove(key, out _);
73+
throw;
74+
}
75+
}
76+
}
77+
}
78+
}

0 commit comments

Comments
 (0)