Skip to content
60 changes: 22 additions & 38 deletions src/Confluent.Kafka/Impl/SafeKafkaHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Refer to LICENSE for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -357,45 +358,35 @@ private IntPtr marshalHeaders(IReadOnlyList<IHeader> headers)
return headersPtr;
}

internal ErrorCode Produce(
internal unsafe ErrorCode Produce(
string topic,
byte[] val, int valOffset, int valLength,
byte[] key, int keyOffset, int keyLength,
ReadOnlyMemory<byte>? val,
ReadOnlyMemory<byte>? key,
int partition,
long timestamp,
IReadOnlyList<IHeader> headers,
IntPtr opaque)
{
var pValue = IntPtr.Zero;
var pKey = IntPtr.Zero;
MemoryHandle? valueHandle = null;
IntPtr valuePtr = IntPtr.Zero;
UIntPtr valueLength = UIntPtr.Zero;

var gchValue = default(GCHandle);
var gchKey = default(GCHandle);
MemoryHandle? keyHandle = null;
IntPtr keyPtr = IntPtr.Zero;
UIntPtr keyLength = UIntPtr.Zero;

if (val == null)
if (val != null)
{
if (valOffset != 0 || valLength != 0)
{
throw new ArgumentException("valOffset and valLength parameters must be 0 when producing null values.");
}
}
else
{
gchValue = GCHandle.Alloc(val, GCHandleType.Pinned);
pValue = Marshal.UnsafeAddrOfPinnedArrayElement(val, valOffset);
valueHandle = val.Value.Pin();
valuePtr = (IntPtr)valueHandle.Value.Pointer;
valueLength = (UIntPtr)val.Value.Length;
}

if (key == null)
if (key != null)
{
if (keyOffset != 0 || keyLength != 0)
{
throw new ArgumentException("keyOffset and keyLength parameters must be 0 when producing null key values.");
}
}
else
{
gchKey = GCHandle.Alloc(key, GCHandleType.Pinned);
pKey = Marshal.UnsafeAddrOfPinnedArrayElement(key, keyOffset);
keyHandle = key.Value.Pin();
keyPtr = (IntPtr)keyHandle.Value.Pointer;
keyLength = (UIntPtr)key.Value.Length;
}

IntPtr headersPtr = marshalHeaders(headers);
Expand All @@ -407,8 +398,8 @@ internal ErrorCode Produce(
topic,
partition,
(IntPtr)MsgFlags.MSG_F_COPY,
pValue, (UIntPtr)valLength,
pKey, (UIntPtr)keyLength,
valuePtr, valueLength,
keyPtr, keyLength,
timestamp,
headersPtr,
opaque);
Expand All @@ -433,15 +424,8 @@ internal ErrorCode Produce(
}
finally
{
if (val != null)
{
gchValue.Free();
}

if (key != null)
{
gchKey.Free();
}
valueHandle?.Dispose();
keyHandle?.Dispose();
}
}

Expand Down
88 changes: 56 additions & 32 deletions src/Confluent.Kafka/Producer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ internal class Config
public PartitionerDelegate defaultPartitioner;
}

private ISerializer<TKey> keySerializer;
private ISerializer<TValue> valueSerializer;
private Func<TKey, SerializationContext, ReadOnlyMemory<byte>?> serializeKey;
private Func<TValue, SerializationContext, ReadOnlyMemory<byte>?> serializeValue;
private IAsyncSerializer<TKey> asyncKeySerializer;
private IAsyncSerializer<TValue> asyncValueSerializer;

Expand All @@ -58,6 +58,14 @@ internal class Config
{ typeof(byte[]), Serializers.ByteArray }
};

private static readonly Dictionary<Type, object> memorySerializeFuncs = new Dictionary<Type, object>
{
[typeof(Memory<byte>)] = (Memory<byte> x, SerializationContext _) => (ReadOnlyMemory<byte>?)x,
[typeof(Memory<byte>?)] = (Memory<byte>? x, SerializationContext _) => (ReadOnlyMemory<byte>?)x,
[typeof(ReadOnlyMemory<byte>)] = (ReadOnlyMemory<byte> x, SerializationContext _) => (ReadOnlyMemory<byte>?)x,
[typeof(ReadOnlyMemory<byte>?)] = (ReadOnlyMemory<byte>? x, SerializationContext _) => x,
};

private int cancellationDelayMaxMs;
private bool disposeHasBeenCalled = false;
private object disposeHasBeenCalledLockObj = new object();
Expand Down Expand Up @@ -279,8 +287,8 @@ private void DeliveryReportCallbackImpl(IntPtr rk, IntPtr rkmessage, IntPtr opaq

private void ProduceImpl(
string topic,
byte[] val, int valOffset, int valLength,
byte[] key, int keyOffset, int keyLength,
ReadOnlyMemory<byte>? val,
ReadOnlyMemory<byte>? key,
Timestamp timestamp,
Partition partition,
IReadOnlyList<IHeader> headers,
Expand Down Expand Up @@ -308,8 +316,8 @@ private void ProduceImpl(

err = KafkaHandle.Produce(
topic,
val, valOffset, valLength,
key, keyOffset, keyLength,
val,
key,
partition.Value,
timestamp.UnixTimestampMs,
headers,
Expand All @@ -325,8 +333,8 @@ private void ProduceImpl(
{
err = KafkaHandle.Produce(
topic,
val, valOffset, valLength,
key, keyOffset, keyLength,
val,
key,
partition.Value,
timestamp.UnixTimestampMs,
headers,
Expand Down Expand Up @@ -508,20 +516,28 @@ private void InitializeSerializers(
// setup key serializer.
if (keySerializer == null && asyncKeySerializer == null)
{
if (!defaultSerializers.TryGetValue(typeof(TKey), out object serializer))
if (defaultSerializers.TryGetValue(typeof(TKey), out object serializer))
{
keySerializer = (ISerializer<TKey>)serializer;
this.serializeKey = (k, ctx) => keySerializer.Serialize(k, ctx)?.AsMemory();
}
else if (memorySerializeFuncs.TryGetValue(typeof(TKey), out object serialize))
{
this.serializeKey = (Func<TKey, SerializationContext, ReadOnlyMemory<byte>?>)serialize;
}
else
{
throw new ArgumentNullException(
$"Key serializer not specified and there is no default serializer defined for type {typeof(TKey).Name}.");
}
this.keySerializer = (ISerializer<TKey>)serializer;
}
else if (keySerializer == null && asyncKeySerializer != null)
{
this.asyncKeySerializer = asyncKeySerializer;
}
else if (keySerializer != null && asyncKeySerializer == null)
{
this.keySerializer = keySerializer;
this.serializeKey = (k, ctx) => keySerializer.Serialize(k, ctx)?.AsMemory();
}
else
{
Expand All @@ -531,20 +547,28 @@ private void InitializeSerializers(
// setup value serializer.
if (valueSerializer == null && asyncValueSerializer == null)
{
if (!defaultSerializers.TryGetValue(typeof(TValue), out object serializer))
if (defaultSerializers.TryGetValue(typeof(TValue), out object serializer))
{
valueSerializer = (ISerializer<TValue>)serializer;
this.serializeValue = (k, ctx) => valueSerializer.Serialize(k, ctx)?.AsMemory();
}
else if (memorySerializeFuncs.TryGetValue(typeof(TValue), out object serialize))
{
this.serializeValue = (Func<TValue, SerializationContext, ReadOnlyMemory<byte>?>)serialize;
}
else
{
throw new ArgumentNullException(
$"Value serializer not specified and there is no default serializer defined for type {typeof(TValue).Name}.");
}
this.valueSerializer = (ISerializer<TValue>)serializer;
}
else if (valueSerializer == null && asyncValueSerializer != null)
{
this.asyncValueSerializer = asyncValueSerializer;
}
else if (valueSerializer != null && asyncValueSerializer == null)
{
this.valueSerializer = valueSerializer;
this.serializeValue = (k, ctx) => valueSerializer.Serialize(k, ctx)?.AsMemory();
}
else
{
Expand Down Expand Up @@ -750,11 +774,11 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(
{
Headers headers = message.Headers ?? new Headers();

byte[] keyBytes;
ReadOnlyMemory<byte>? keyBytes;
try
{
keyBytes = (keySerializer != null)
? keySerializer.Serialize(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
keyBytes = (serializeKey != null)
? serializeKey(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
: await asyncKeySerializer.SerializeAsync(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers)).ConfigureAwait(false);
}
catch (Exception ex)
Expand All @@ -769,11 +793,11 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(
ex);
}

byte[] valBytes;
ReadOnlyMemory<byte>? valBytes;
try
{
valBytes = (valueSerializer != null)
? valueSerializer.Serialize(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
valBytes = (serializeValue != null)
? serializeValue(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
: await asyncValueSerializer.SerializeAsync(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers)).ConfigureAwait(false);
}
catch (Exception ex)
Expand Down Expand Up @@ -805,8 +829,8 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(

ProduceImpl(
topicPartition.Topic,
valBytes, 0, valBytes == null ? 0 : valBytes.Length,
keyBytes, 0, keyBytes == null ? 0 : keyBytes.Length,
valBytes,
keyBytes,
message.Timestamp, topicPartition.Partition, headers.BackingList,
handler);

Expand All @@ -816,8 +840,8 @@ public async Task<DeliveryResult<TKey, TValue>> ProduceAsync(
{
ProduceImpl(
topicPartition.Topic,
valBytes, 0, valBytes == null ? 0 : valBytes.Length,
keyBytes, 0, keyBytes == null ? 0 : keyBytes.Length,
valBytes,
keyBytes,
message.Timestamp, topicPartition.Partition, headers.BackingList,
null);

Expand Down Expand Up @@ -873,11 +897,11 @@ public void Produce(

Headers headers = message.Headers ?? new Headers();

byte[] keyBytes;
ReadOnlyMemory<byte>? keyBytes;
try
{
keyBytes = (keySerializer != null)
? keySerializer.Serialize(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
keyBytes = (serializeKey != null)
? serializeKey(message.Key, new SerializationContext(MessageComponentType.Key, topicPartition.Topic, headers))
: throw new InvalidOperationException("Produce called with an IAsyncSerializer key serializer configured but an ISerializer is required.");
}
catch (Exception ex)
Expand All @@ -892,11 +916,11 @@ public void Produce(
ex);
}

byte[] valBytes;
ReadOnlyMemory<byte>? valBytes;
try
{
valBytes = (valueSerializer != null)
? valueSerializer.Serialize(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
valBytes = (serializeValue != null)
? serializeValue(message.Value, new SerializationContext(MessageComponentType.Value, topicPartition.Topic, headers))
: throw new InvalidOperationException("Produce called with an IAsyncSerializer value serializer configured but an ISerializer is required.");
}
catch (Exception ex)
Expand All @@ -915,8 +939,8 @@ public void Produce(
{
ProduceImpl(
topicPartition.Topic,
valBytes, 0, valBytes == null ? 0 : valBytes.Length,
keyBytes, 0, keyBytes == null ? 0 : keyBytes.Length,
valBytes,
keyBytes,
message.Timestamp, topicPartition.Partition,
headers.BackingList,
deliveryHandler == null
Expand Down
65 changes: 65 additions & 0 deletions test/Confluent.Kafka.IntegrationTests/Tests/Producer_Produce.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,71 @@ public void Producer_Produce(string bootstrapServers)

Assert.Equal(2, count);

// Memory<byte> case.

count = 0;
Action<DeliveryReport<Memory<byte>, ReadOnlyMemory<byte>>> dh3 = dr =>
{
Assert.Equal(ErrorCode.NoError, dr.Error.Code);
Assert.Equal(PersistenceStatus.Persisted, dr.Status);
Assert.Equal((Partition)0, dr.Partition);
Assert.Equal(singlePartitionTopic, dr.Topic);
Assert.True(dr.Offset >= 0);
Assert.Equal($"test key {count + 42}", Encoding.UTF8.GetString(dr.Message.Key.Span));
Assert.Equal($"test val {count + 42}", Encoding.UTF8.GetString(dr.Message.Value.Span));
Assert.Equal(TimestampType.CreateTime, dr.Message.Timestamp.Type);
Assert.True(Math.Abs((DateTime.UtcNow - dr.Message.Timestamp.UtcDateTime).TotalMinutes) < 1.0);
count += 1;
};

using (var producer = new TestProducerBuilder<Memory<byte>, ReadOnlyMemory<byte>>(producerConfig).Build())
{
producer.Produce(
new TopicPartition(singlePartitionTopic, 0),
new Message<Memory<byte>, ReadOnlyMemory<byte>> { Key = Encoding.UTF8.GetBytes("test key 42"), Value = Encoding.UTF8.GetBytes("test val 42") }, dh3);

producer.Produce(
singlePartitionTopic,
new Message<Memory<byte>, ReadOnlyMemory<byte>> { Key = Encoding.UTF8.GetBytes("test key 43"), Value = Encoding.UTF8.GetBytes("test val 43") }, dh3);

producer.Flush(TimeSpan.FromSeconds(10));
}

Assert.Equal(2, count);

// Memory<byte>? case.

count = 0;
Action<DeliveryReport<ReadOnlyMemory<byte>?, Memory<byte>?>> dh4 = dr =>
{
Assert.Equal(ErrorCode.NoError, dr.Error.Code);
Assert.Equal(PersistenceStatus.Persisted, dr.Status);
Assert.Equal((Partition)0, dr.Partition);
Assert.Equal(singlePartitionTopic, dr.Topic);
Assert.True(dr.Offset >= 0);
Assert.True(dr.Message.Key.HasValue);
Assert.Equal($"test key {count + 42}", Encoding.UTF8.GetString(dr.Message.Key.Value.Span));
Assert.True(dr.Message.Value.HasValue);
Assert.Equal($"test val {count + 42}", Encoding.UTF8.GetString(dr.Message.Value.Value.Span));
Assert.Equal(TimestampType.CreateTime, dr.Message.Timestamp.Type);
Assert.True(Math.Abs((DateTime.UtcNow - dr.Message.Timestamp.UtcDateTime).TotalMinutes) < 1.0);
count += 1;
};

using (var producer = new TestProducerBuilder<ReadOnlyMemory<byte>?, Memory<byte>?>(producerConfig).Build())
{
producer.Produce(
new TopicPartition(singlePartitionTopic, 0),
new Message<ReadOnlyMemory<byte>?, Memory<byte>?> { Key = Encoding.UTF8.GetBytes("test key 42"), Value = Encoding.UTF8.GetBytes("test val 42") }, dh4);

producer.Produce(
singlePartitionTopic,
new Message<ReadOnlyMemory<byte>?, Memory<byte>?> { Key = Encoding.UTF8.GetBytes("test key 43"), Value = Encoding.UTF8.GetBytes("test val 43") }, dh4);

producer.Flush(TimeSpan.FromSeconds(10));
}

Assert.Equal(2, count);

Assert.Equal(0, Library.HandleCount);
LogToFile("end Producer_Produce");
Expand Down
Loading