Skip to content

Natively support Memory<byte> #2311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
60 changes: 22 additions & 38 deletions src/Confluent.Kafka/Impl/SafeKafkaHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Refer to LICENSE for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -357,45 +358,35 @@ private IntPtr marshalHeaders(IReadOnlyList<IHeader> headers)
return headersPtr;
}

internal ErrorCode Produce(
internal unsafe ErrorCode Produce(
string topic,
byte[] val, int valOffset, int valLength,
byte[] key, int keyOffset, int keyLength,
ReadOnlyMemory<byte>? val,
ReadOnlyMemory<byte>? key,
int partition,
long timestamp,
IReadOnlyList<IHeader> headers,
IntPtr opaque)
{
var pValue = IntPtr.Zero;
var pKey = IntPtr.Zero;
MemoryHandle? valueHandle = null;
IntPtr valuePtr = IntPtr.Zero;
UIntPtr valueLength = UIntPtr.Zero;

var gchValue = default(GCHandle);
var gchKey = default(GCHandle);
MemoryHandle? keyHandle = null;
IntPtr keyPtr = IntPtr.Zero;
UIntPtr keyLength = UIntPtr.Zero;

if (val == null)
if (val != null)
{
if (valOffset != 0 || valLength != 0)
{
throw new ArgumentException("valOffset and valLength parameters must be 0 when producing null values.");
}
}
else
{
gchValue = GCHandle.Alloc(val, GCHandleType.Pinned);
pValue = Marshal.UnsafeAddrOfPinnedArrayElement(val, valOffset);
valueHandle = val.Value.Pin();
valuePtr = (IntPtr)valueHandle.Value.Pointer;
valueLength = (UIntPtr)val.Value.Length;
}

if (key == null)
if (key != null)
{
if (keyOffset != 0 || keyLength != 0)
{
throw new ArgumentException("keyOffset and keyLength parameters must be 0 when producing null key values.");
}
}
else
{
gchKey = GCHandle.Alloc(key, GCHandleType.Pinned);
pKey = Marshal.UnsafeAddrOfPinnedArrayElement(key, keyOffset);
keyHandle = key.Value.Pin();
keyPtr = (IntPtr)keyHandle.Value.Pointer;
keyLength = (UIntPtr)key.Value.Length;
}

IntPtr headersPtr = marshalHeaders(headers);
Expand All @@ -407,8 +398,8 @@ internal ErrorCode Produce(
topic,
partition,
(IntPtr)MsgFlags.MSG_F_COPY,
pValue, (UIntPtr)valLength,
pKey, (UIntPtr)keyLength,
valuePtr, valueLength,
keyPtr, keyLength,
timestamp,
headersPtr,
opaque);
Expand All @@ -433,15 +424,8 @@ internal ErrorCode Produce(
}
finally
{
if (val != null)
{
gchValue.Free();
}

if (key != null)
{
gchKey.Free();
}
valueHandle?.Dispose();
keyHandle?.Dispose();
}
}

Expand Down
88 changes: 56 additions & 32 deletions src/Confluent.Kafka/Producer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ internal class Config
public PartitionerDelegate defaultPartitioner;
}

private ISerializer<TKey> keySerializer;
private ISerializer<TValue> valueSerializer;
private Func<TKey, SerializationContext, ReadOnlyMemory<byte>?> serializeKey;
private Func<TValue, SerializationContext, ReadOnlyMemory<byte>?> serializeValue;
private IAsyncSerializer<TKey> asyncKeySerializer;
private IAsyncSerializer<TValue> asyncValueSerializer;

Expand All @@ -58,6 +58,14 @@ internal class Config
{ typeof(byte[]), Serializers.ByteArray }
};

private static readonly Dictionary<Type, object> memorySerializeFuncs = new Dictionary<Type, object>
{
[typeof(Memory<byte>)] = (Memory<byte> x, SerializationContext _) => (ReadOnlyMemory<byte>?)x,
[typeof(Memory<byte>?)] = (Memory<byte>? x, SerializationContext _) => (ReadOnlyMemory<byte>?)x,
[typeof(ReadOnlyMemory<byte>)] = (ReadOnlyMemory<byte> x, SerializationContext _) => (ReadOnlyMemory<byte>?)x,
[typeof(ReadOnlyMemory<byte>?)] = (ReadOnlyMemory<byte>? x, SerializationContext _) => x,
};

private int cancellationDelayMaxMs;
private bool disposeHasBeenCalled = false;
private object disposeHasBeenCalledLockObj = new object();
Expand Down Expand Up @@ -279,8 +287,8 @@ private void DeliveryReportCallbackImpl(IntPtr rk, IntPtr rkmessage, IntPtr opaq

private void ProduceImpl(
string topic,
byte[] val, int valOffset, int valLength,
byte[] key, int keyOffset, int keyLength,
ReadOnlyMemory<byte>? val,
ReadOnlyMemory<byte>? key,
Timestamp timestamp,
Partition partition,
IReadOnlyList<IHeader> headers,
Expand Down Expand Up @@ -308,8 +316,8 @@ private void ProduceImpl(

err = KafkaHandle.Produce(
topic,
val, valOffset, valLength,
key, keyOffset, keyLength,
val,
key,
partition.Value,
timestamp.UnixTimestampMs,
headers,
Expand All @@ -325,8 +333,8 @@ private void ProduceImpl(
{
err = KafkaHandle.Produce(
topic,
val, valOffset, valLength,
key, keyOffset, keyLength,
val,
key,
partition.Value,
timestamp.UnixTimestampMs,
headers,
Expand Down Expand Up @@ -508,20 +516,28 @@ private void InitializeSerializers(
// setup key serializer.
if (keySerializer == null && asyncKeySerializer == null)
{
if (!defaultSerializers.TryGetValue(typeof(TKey), out object serializer))
if (defaultSerializers.TryGetValue(typeof(TKey), out object serializer))
{
keySerializer = (ISerializer<TKey>)serializer;
this.serializeKey = (k, ctx) => keySerializer.Serialize(k, ctx)?.AsMemory();
}
else if (memorySerializeFuncs.TryGetValue(typeof(TKey), out object serialize))
{
this.serializeKey = (Func<TKey, SerializationContext, ReadOnlyMemory<byte>?>)serialize;
}
else
{
throw new ArgumentNullException(
$"Key serializer not specified and there is no default serializer defined for type {typeof(TKey).Name}.");
}
this.keySerializer = (ISerializer<TKey>)serializer;
}
else if (keySerializer == null && asyncKeySerializer != null)
{
this.asyncKeySerializer = asyncKeySerializer;
}
else if (keySerializer != null && asyncKeySerializer == null)
{
this.keySerializer = keySerializer;
this.serializeKey = (k, ctx) => keySerializer.Serialize(k, ctx)?.AsMemory();
}
else
{
Expand All @@ -531,20 +547,28 @@ private void InitializeSerializers(
// setup value serializer.
if (valueSerializer == null && asyncValueSerializer == null)
{
if (!defaultSerializers.TryGetValue(typeof(TValue), out object serializer))
if (defaultSerializers.TryGetValue(typeof(TValue), out object serializer))
{
valueSerializer = (ISerializer<TValue>)serializer;
this.serializeValue = (k, ctx) => valueSerializer.Serialize(k, ctx)?.AsMemory();
}
else if (memorySerializeFuncs.TryGetValue(typeof(TValue), out object serialize))
{
this.serializeValue = (Func<TValue, SerializationContext, ReadOnlyMemory<byte>?>)serialize;
}
else
{
throw new ArgumentNullException(
$"Value serializer not specified and there is no default serializer defined for type {typeof(TValue).Name}.");
}
this.valueSerializer = (ISerializer<TValue>)serializer;
}
else if (valueSerializer == null && asyncValueSerializer != null)
{
this.asyncValueSerializer = asyncValueSerializer;
}
else if (valueSerializer != null && asyncValueSerializer == null)
{
this.valueSerializer = valueSerializer;
this.serializeValue = (k, ctx) => valueSerializer.Serialize(k, ctx)?.AsMemory();
}
else
{
Expand Down Expand Up @@ -750,11 +774,11 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(
{
Headers headers = message.Headers ?? new Headers();

byte[] keyBytes;
ReadOnlyMemory<byte>? keyBytes;
try
{
keyBytes = (keySerializer != null)
? keySerializer.Serialize(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
keyBytes = (serializeKey != null)
? serializeKey(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
: await asyncKeySerializer.SerializeAsync(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers)).ConfigureAwait(false);
}
catch (Exception ex)
Expand All @@ -769,11 +793,11 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(
ex);
}

byte[] valBytes;
ReadOnlyMemory<byte>? valBytes;
try
{
valBytes = (valueSerializer != null)
? valueSerializer.Serialize(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
valBytes = (serializeValue != null)
? serializeValue(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
: await asyncValueSerializer.SerializeAsync(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers)).ConfigureAwait(false);
}
catch (Exception ex)
Expand Down Expand Up @@ -805,8 +829,8 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(

ProduceImpl(
topicPartition.Topic,
valBytes, 0, valBytes == null ? 0 : valBytes.Length,
keyBytes, 0, keyBytes == null ? 0 : keyBytes.Length,
valBytes,
keyBytes,
message.Timestamp, topicPartition.Partition, headers.BackingList,
handler);

Expand All @@ -816,8 +840,8 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(
{
ProduceImpl(
topicPartition.Topic,
valBytes, 0, valBytes == null ? 0 : valBytes.Length,
keyBytes, 0, keyBytes == null ? 0 : keyBytes.Length,
valBytes,
keyBytes,
message.Timestamp, topicPartition.Partition, headers.BackingList,
null);

Expand Down Expand Up @@ -873,11 +897,11 @@ public void Produce(

Headers headers = message.Headers ?? new Headers();

byte[] keyBytes;
ReadOnlyMemory<byte>? keyBytes;
try
{
keyBytes = (keySerializer != null)
? keySerializer.Serialize(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
keyBytes = (serializeKey != null)
? serializeKey(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
: throw new InvalidOperationException("Produce called with an IAsyncSerializer key serializer configured but an ISerializer is required.");
}
catch (Exception ex)
Expand All @@ -892,11 +916,11 @@ public void Produce(
ex);
}

byte[] valBytes;
ReadOnlyMemory<byte>? valBytes;
try
{
valBytes = (valueSerializer != null)
? valueSerializer.Serialize(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
valBytes = (serializeValue != null)
? serializeValue(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
: throw new InvalidOperationException("Produce called with an IAsyncSerializer value serializer configured but an ISerializer is required.");
}
catch (Exception ex)
Expand All @@ -915,8 +939,8 @@ public void Produce(
{
ProduceImpl(
topicPartition.Topic,
valBytes, 0, valBytes == null ? 0 : valBytes.Length,
keyBytes, 0, keyBytes == null ? 0 : keyBytes.Length,
valBytes,
keyBytes,
message.Timestamp, topicPartition.Partition,
headers.BackingList,
deliveryHandler == null
Expand Down
65 changes: 65 additions & 0 deletions test/Confluent.Kafka.IntegrationTests/Tests/Producer_Produce.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,71 @@ public void Producer_Produce(string bootstrapServers)

Assert.Equal(2, count);

// Memory<byte> case.

count = 0;
Action<DeliveryReport<Memory<byte>, ReadOnlyMemory<byte>>> dh3 = dr =>
{
Assert.Equal(ErrorCode.NoError, dr.Error.Code);
Assert.Equal(PersistenceStatus.Persisted, dr.Status);
Assert.Equal((Partition)0, dr.Partition);
Assert.Equal(singlePartitionTopic, dr.Topic);
Assert.True(dr.Offset >= 0);
Assert.Equal($"test key {count + 42}", Encoding.UTF8.GetString(dr.Message.Key.Span));
Assert.Equal($"test val {count + 42}", Encoding.UTF8.GetString(dr.Message.Value.Span));
Assert.Equal(TimestampType.CreateTime, dr.Message.Timestamp.Type);
Assert.True(Math.Abs((DateTime.UtcNow - dr.Message.Timestamp.UtcDateTime).TotalMinutes) < 1.0);
count += 1;
};

using (var producer = new TestProducerBuilder<Memory<byte>, ReadOnlyMemory<byte>>(producerConfig).Build())
{
producer.Produce(
new TopicPartition(singlePartitionTopic, 0),
new Message<Memory<byte>, ReadOnlyMemory<byte>> { Key = Encoding.UTF8.GetBytes("test key 42"), Value = Encoding.UTF8.GetBytes("test val 42") }, dh3);

producer.Produce(
singlePartitionTopic,
new Message<Memory<byte>, ReadOnlyMemory<byte>> { Key = Encoding.UTF8.GetBytes("test key 43"), Value = Encoding.UTF8.GetBytes("test val 43") }, dh3);

producer.Flush(TimeSpan.FromSeconds(10));
}

Assert.Equal(2, count);

// Memory<byte>? case.

count = 0;
Action<DeliveryReport<ReadOnlyMemory<byte>?, Memory<byte>?>> dh4 = dr =>
{
Assert.Equal(ErrorCode.NoError, dr.Error.Code);
Assert.Equal(PersistenceStatus.Persisted, dr.Status);
Assert.Equal((Partition)0, dr.Partition);
Assert.Equal(singlePartitionTopic, dr.Topic);
Assert.True(dr.Offset >= 0);
Assert.True(dr.Message.Key.HasValue);
Assert.Equal($"test key {count + 42}", Encoding.UTF8.GetString(dr.Message.Key.Value.Span));
Assert.True(dr.Message.Value.HasValue);
Assert.Equal($"test val {count + 42}", Encoding.UTF8.GetString(dr.Message.Value.Value.Span));
Assert.Equal(TimestampType.CreateTime, dr.Message.Timestamp.Type);
Assert.True(Math.Abs((DateTime.UtcNow - dr.Message.Timestamp.UtcDateTime).TotalMinutes) < 1.0);
count += 1;
};

using (var producer = new TestProducerBuilder<ReadOnlyMemory<byte>?, Memory<byte>?>(producerConfig).Build())
{
producer.Produce(
new TopicPartition(singlePartitionTopic, 0),
new Message<ReadOnlyMemory<byte>?, Memory<byte>?> { Key = Encoding.UTF8.GetBytes("test key 42"), Value = Encoding.UTF8.GetBytes("test val 42") }, dh4);

producer.Produce(
singlePartitionTopic,
new Message<ReadOnlyMemory<byte>?, Memory<byte>?> { Key = Encoding.UTF8.GetBytes("test key 43"), Value = Encoding.UTF8.GetBytes("test val 43") }, dh4);

producer.Flush(TimeSpan.FromSeconds(10));
}

Assert.Equal(2, count);

Assert.Equal(0, Library.HandleCount);
LogToFile("end Producer_Produce");
Expand Down
Loading