Skip to content

Commit 6ed642c

Browse files
authored
Fix rule transformation of list properties (#2475)
* Fix transformation of list properties * Use Task.WhenAll
1 parent 8abfe1c commit 6ed642c

File tree

9 files changed

+377
-90
lines changed

9 files changed

+377
-90
lines changed

src/Confluent.SchemaRegistry.Serdes.Avro/AvroUtils.cs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,14 @@ public static async Task<object> Transform(RuleContext ctx, Avro.Schema schema,
5555
return await Transform(ctx, us[unionIndex], message, fieldTransform).ConfigureAwait(false);
5656
case Avro.Schema.Type.Array:
5757
ArraySchema a = (ArraySchema)schema;
58-
var arrayTasks = ((IList<object>)message)
59-
.Select(it => Transform(ctx, a.ItemSchema, it, fieldTransform))
60-
.ToList();
61-
object[] items = await Task.WhenAll(arrayTasks).ConfigureAwait(false);
62-
return items.ToList();
58+
var arrayTransformer = (int index, object elem) =>
59+
Transform(ctx, a.ItemSchema, elem, fieldTransform);
60+
return await Utils.TransformEnumerableAsync(message, arrayTransformer).ConfigureAwait(false);
6361
case Avro.Schema.Type.Map:
6462
MapSchema ms = (MapSchema)schema;
65-
var dictTasks = ((IDictionary<object, object>)message)
66-
.Select(it => Transform(ctx, ms.ValueSchema, it.Value, fieldTransform)
67-
.ContinueWith(t => new KeyValuePair<object, object>(it.Key, it.Value)))
68-
.ToList();
69-
KeyValuePair<object, object>[] entries = await Task.WhenAll(dictTasks).ConfigureAwait(false);
70-
return entries.ToDictionary(it => it.Key, it => it.Value);
63+
var mapTransformer = (object key, object value) =>
64+
Transform(ctx, ms.ValueSchema, value, fieldTransform);
65+
return await Utils.TransformDictionaryAsync(message, mapTransformer).ConfigureAwait(false);
7166
case Avro.Schema.Type.Record:
7267
RecordSchema rs = (RecordSchema)schema;
7368
foreach (Field f in rs.Fields)

src/Confluent.SchemaRegistry.Serdes.Json/JsonUtils.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,9 @@ public static async Task<object> Transform(RuleContext ctx, JsonSchema schema, s
7474
}
7575

7676
JsonSchema subschema = schema.Item;
77-
var tasks = ((IList<object>)message)
78-
.Select((it, index) => Transform(ctx, subschema, path + '[' + index + ']', it, fieldTransform))
79-
.ToList();
80-
object[] items = await Task.WhenAll(tasks).ConfigureAwait(false);
81-
return items.ToList();
77+
var transformer = (int index, object elem) =>
78+
Transform(ctx, subschema, path + '[' + index + ']', elem, fieldTransform);
79+
return await Utils.TransformEnumerableAsync(message, transformer).ConfigureAwait(false);
8280
}
8381
else if (schema.IsObject)
8482
{

src/Confluent.SchemaRegistry.Serdes.Protobuf/ProtobufUtils.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,9 @@ internal static async Task<object> Transform(RuleContext ctx, object desc, objec
101101
&& (message.GetType().GetGenericTypeDefinition() == typeof(List<>)
102102
|| message.GetType().GetGenericTypeDefinition() == typeof(IList<>))))
103103
{
104-
var tasks = ((IList<object>)message)
105-
.Select(it => Transform(ctx, desc, it, fieldTransform))
106-
.ToList();
107-
object[] items = await Task.WhenAll(tasks).ConfigureAwait(false);
108-
return items.ToList();
104+
var transformer = (int index, object elem) =>
105+
Transform(ctx, desc, elem, fieldTransform);
106+
return await Utils.TransformEnumerableAsync(message, transformer).ConfigureAwait(false);
109107
}
110108
else if (typeof(IDictionary).IsAssignableFrom(message.GetType())
111109
|| (message.GetType().IsGenericType

src/Confluent.SchemaRegistry.Serdes.Protobuf/Utils.cs

Lines changed: 0 additions & 69 deletions
This file was deleted.

src/Confluent.SchemaRegistry/Utils.cs

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using System.Collections.Generic;
2020
using System.IO;
2121
using System.Linq;
22+
using System.Threading.Tasks;
2223

2324
#if NET8_0_OR_GREATER
2425
using System.Buffers.Text;
@@ -207,5 +208,140 @@ public static Guid GuidFromBigEndian(byte[] bigEndianBytes)
207208
// Create the Guid from the correctly formatted mixed-endian byte array
208209
return new Guid(mixedEndianBytes);
209210
}
211+
212+
/// <summary>
213+
/// Asynchronously transforms each element of an IEnumerable&lt;T&gt;,
214+
/// passing in its zero-based index and the element itself, using a
215+
/// Func&lt;int, object, Task&lt;object&gt;&gt;, and returns a List&lt;T&gt; of the transformed elements.
216+
/// </summary>
217+
/// <param name="sourceEnumerable">
218+
/// An object implementing IEnumerable&lt;T&gt; for some T.
219+
/// </param>
220+
/// <param name="indexedTransformer">
221+
/// A function that takes (index, element as object) and returns a Task whose Result
222+
/// is the new element (as object). The returned object must be castable to T.
223+
/// </param>
224+
/// <returns>
225+
/// A Task whose Result is a List&lt;T&gt; (boxed as object) containing all transformed elements.
226+
/// Await and then cast back to IEnumerable&lt;T&gt; to enumerate.
227+
/// </returns>
228+
public static async Task<object> TransformEnumerableAsync(
229+
object sourceEnumerable,
230+
Func<int, object, Task<object>> indexedTransformer)
231+
{
232+
if (sourceEnumerable == null)
233+
throw new ArgumentNullException(nameof(sourceEnumerable));
234+
if (indexedTransformer == null)
235+
throw new ArgumentNullException(nameof(indexedTransformer));
236+
237+
// 1. Find the IEnumerable<T> interface on the source object
238+
var srcType = sourceEnumerable.GetType();
239+
var enumInterface = srcType
240+
.GetInterfaces()
241+
.FirstOrDefault(i =>
242+
i.IsGenericType &&
243+
i.GetGenericTypeDefinition() == typeof(IEnumerable<>));
244+
245+
if (enumInterface == null)
246+
throw new ArgumentException("Source must implement IEnumerable<T>", nameof(sourceEnumerable));
247+
248+
// 2. Extract the element type T
249+
var elementType = enumInterface.GetGenericArguments()[0];
250+
251+
// 3. Build a List<T> at runtime
252+
var listType = typeof(List<>).MakeGenericType(elementType);
253+
var resultList = (IList)Activator.CreateInstance(listType);
254+
255+
// 4. Kick off all transforms in parallel
256+
var tasks = new List<Task<object>>();
257+
int index = 0;
258+
foreach (var item in (IEnumerable)sourceEnumerable)
259+
{
260+
tasks.Add(indexedTransformer(index, item));
261+
index++;
262+
}
263+
264+
// 5. Await them all at once
265+
var results = await Task.WhenAll(tasks).ConfigureAwait(false);
266+
267+
// 6. Populate the result list in original order
268+
foreach (var transformed in results)
269+
{
270+
resultList.Add(transformed);
271+
}
272+
273+
// 7. Return the List<T> as object
274+
return resultList;
275+
}
276+
277+
/// <summary>
278+
/// Asynchronously transforms each value of an IDictionary&lt;K,V&gt;,
279+
/// by invoking the provided Func&lt;object, object, Task&lt;object&gt;&gt;
280+
/// passing in the key and the original value
281+
/// and returns a new Dictionary&lt;K,V&gt; whose values are the awaited results.
282+
/// </summary>
283+
/// <param name="sourceDictionary">
284+
/// An object implementing IDictionary&lt;K,V&gt; for some K,V.
285+
/// </param>
286+
/// <param name="transformer">
287+
/// A function that takes (key as object, value as object) and returns a Task whose Result
288+
/// is the new value (as object). The returned object must be castable to V.
289+
/// </param>
290+
/// <returns>
291+
/// A Task whose Result is a Dictionary&lt;K,V&gt; containing all the transformed values.
292+
/// Await and then cast back to IDictionary&lt;K,V&gt; to enumerate.
293+
/// </returns>
294+
public static async Task<object> TransformDictionaryAsync(
295+
object sourceDictionary,
296+
Func<object, object, Task<object>> transformer)
297+
{
298+
if (sourceDictionary == null)
299+
throw new ArgumentNullException(nameof(sourceDictionary));
300+
if (transformer == null)
301+
throw new ArgumentNullException(nameof(transformer));
302+
303+
// 1. Find the IDictionary<K,V> interface on the source object
304+
var srcType = sourceDictionary.GetType();
305+
var dictInterface = srcType
306+
.GetInterfaces()
307+
.FirstOrDefault(i =>
308+
i.IsGenericType &&
309+
i.GetGenericTypeDefinition() == typeof(IDictionary<,>));
310+
311+
if (dictInterface == null)
312+
throw new ArgumentException("Source must implement IDictionary<K,V>", nameof(sourceDictionary));
313+
314+
// 2. Extract K and V
315+
var genericArgs = dictInterface.GetGenericArguments();
316+
var keyType = genericArgs[0];
317+
var valueType = genericArgs[1];
318+
319+
// 3. Create a Dictionary<K,V> at runtime
320+
var resultDictType = typeof(Dictionary<,>).MakeGenericType(keyType, valueType);
321+
var resultDict = (IDictionary)Activator.CreateInstance(resultDictType);
322+
323+
// 4. Enumerate the source via the non‐generic IDictionary interface
324+
var nonGenericDict = (IDictionary)sourceDictionary;
325+
var keys = new List<object>();
326+
var tasks = new List<Task<object>>();
327+
328+
foreach (DictionaryEntry entry in nonGenericDict)
329+
{
330+
keys.Add(entry.Key);
331+
tasks.Add(transformer(entry.Key, entry.Value));
332+
}
333+
334+
// 5. Await all transformations
335+
var transformedValues = await Task.WhenAll(tasks).ConfigureAwait(false);
336+
337+
// 6. Reconstruct the new dictionary in original order
338+
for (int i = 0; i < keys.Count; i++)
339+
{
340+
resultDict.Add(keys[i], transformedValues[i]);
341+
}
342+
343+
// 7. Return boxed Dictionary<K,V>
344+
return resultDict;
345+
}
210346
}
211347
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"namespace": "confluent.io.examples.serialization.avro",
3+
"name": "Complex",
4+
"type": "record",
5+
"fields": [
6+
{
7+
"name": "arrayField",
8+
"type": {"type": "array", "items": "string"}
9+
},
10+
{
11+
"name": "mapField",
12+
"type": {"type": "map", "values": "string"}
13+
},
14+
{
15+
"name": "unionField",
16+
"type": ["null", "string"], "confluent:tags": ["PII"]
17+
}
18+
]
19+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// ------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// Generated by avrogen, version 1.12.0+8c27801dc8d42ccc00997f25c0b8f45f8d4a233e
4+
// Changes to this file may cause incorrect behavior and will be lost if code
5+
// is regenerated
6+
// </auto-generated>
7+
// ------------------------------------------------------------------------------
8+
namespace Confluent.Kafka.Examples.AvroSpecific
9+
{
10+
using System;
11+
using System.Collections.Generic;
12+
using System.Text;
13+
using global::Avro;
14+
using global::Avro.Specific;
15+
16+
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("avrogen", "1.12.0+8c27801dc8d42ccc00997f25c0b8f45f8d4a233e")]
17+
public partial class Complex : global::Avro.Specific.ISpecificRecord
18+
{
19+
public static global::Avro.Schema _SCHEMA = global::Avro.Schema.Parse(@"{""type"":""record"",""name"":""Complex"",""namespace"":""Confluent.Kafka.Examples.AvroSpecific"",""fields"":[{""name"":""arrayField"",""type"":{""type"":""array"",""items"":""string""}},{""name"":""mapField"",""type"":{""type"":""map"",""values"":""string""}},{""name"":""unionField"",""type"":[""null"",""string""],""confluent:tags"":[""PII""]}]}");
20+
private IList<System.String> _arrayField;
21+
private IDictionary<string,System.String> _mapField;
22+
private string _unionField;
23+
public virtual global::Avro.Schema Schema
24+
{
25+
get
26+
{
27+
return Complex._SCHEMA;
28+
}
29+
}
30+
public IList<System.String> arrayField
31+
{
32+
get
33+
{
34+
return this._arrayField;
35+
}
36+
set
37+
{
38+
this._arrayField = value;
39+
}
40+
}
41+
public IDictionary<string,System.String> mapField
42+
{
43+
get
44+
{
45+
return this._mapField;
46+
}
47+
set
48+
{
49+
this._mapField = value;
50+
}
51+
}
52+
public string unionField
53+
{
54+
get
55+
{
56+
return this._unionField;
57+
}
58+
set
59+
{
60+
this._unionField = value;
61+
}
62+
}
63+
public virtual object Get(int fieldPos)
64+
{
65+
switch (fieldPos)
66+
{
67+
case 0: return this.arrayField;
68+
case 1: return this.mapField;
69+
case 2: return this.unionField;
70+
default: throw new global::Avro.AvroRuntimeException("Bad index " + fieldPos + " in Get()");
71+
};
72+
}
73+
public virtual void Put(int fieldPos, object fieldValue)
74+
{
75+
switch (fieldPos)
76+
{
77+
case 0: this.arrayField = (IList<System.String>)fieldValue; break;
78+
case 1: this.mapField = (IDictionary<string,System.String>)fieldValue; break;
79+
case 2: this.unionField = (System.String)fieldValue; break;
80+
default: throw new global::Avro.AvroRuntimeException("Bad index " + fieldPos + " in Put()");
81+
};
82+
}
83+
}
84+
}

0 commit comments

Comments
 (0)