diff --git a/pulsar/message_chunking_test.go b/pulsar/message_chunking_test.go
index c2e1113df3..b85e914f8c 100644
--- a/pulsar/message_chunking_test.go
+++ b/pulsar/message_chunking_test.go
@@ -544,7 +544,7 @@ func sendSingleChunk(p Producer, uuid string, chunkID int, totalChunks int) {
 		Payload: []byte(fmt.Sprintf("chunk-%s-%d|", uuid, chunkID)),
 	}
 	wholePayload := msg.Payload
-	producerImpl := p.(*producer).producers[0].(*partitionProducer)
+	producerImpl := p.(*producer).getProducer(0).(*partitionProducer)
 	mm := producerImpl.genMetadata(msg, len(wholePayload), time.Now())
 	mm.Uuid = proto.String(uuid)
 	mm.NumChunksFromMsg = proto.Int32(int32(totalChunks))
diff --git a/pulsar/producer_impl.go b/pulsar/producer_impl.go
index 8e970d28fe..a464f14c43 100644
--- a/pulsar/producer_impl.go
+++ b/pulsar/producer_impl.go
@@ -23,7 +23,6 @@ import (
 	"sync"
 	"sync/atomic"
 	"time"
-	"unsafe"
 
 	"github.com/apache/pulsar-client-go/pulsar/crypto"
 	"github.com/apache/pulsar-client-go/pulsar/internal"
@@ -48,13 +47,10 @@ const (
 )
 
 type producer struct {
-	sync.RWMutex
 	client        *client
 	options       *ProducerOptions
 	topic         string
-	producers     []Producer
-	producersPtr  unsafe.Pointer
-	numPartitions uint32
+	producers     atomic.Value
 	messageRouter func(*ProducerMessage, TopicMetadata) int
 	closeOnce     sync.Once
 	stopDiscovery func()
@@ -195,10 +191,7 @@ func (p *producer) internalCreatePartitionsProducers() error {
 	oldNumPartitions := 0
 	newNumPartitions := len(partitions)
 
-	p.Lock()
-	defer p.Unlock()
-
-	oldProducers := p.producers
+	oldProducers := p.getProducers()
 	oldNumPartitions = len(oldProducers)
 
 	if oldProducers != nil {
@@ -213,14 +206,14 @@ func (p *producer) internalCreatePartitionsProducers() error {
 
 	}
 
-	p.producers = make([]Producer, newNumPartitions)
+	producers := make([]Producer, newNumPartitions)
 
 	// When for some reason (eg: forced deletion of sub partition) causes oldNumPartitions> newNumPartitions,
 	// we need to rebuild the cache of new producers, otherwise the array will be out of bounds.
 	if oldProducers != nil && oldNumPartitions < newNumPartitions {
 		// Copy over the existing consumer instances
 		for i := 0; i < oldNumPartitions; i++ {
-			p.producers[i] = oldProducers[i]
+			producers[i] = oldProducers[i]
 		}
 	}
 
@@ -251,20 +244,23 @@ func (p *producer) internalCreatePartitionsProducers() error {
 		}(partitionIdx, partition)
 	}
 
+	var newProducers []Producer
+
 	for i := 0; i < partitionsToAdd; i++ {
 		pe, ok := <-c
 		if ok {
 			if pe.err != nil {
 				err = pe.err
 			} else {
-				p.producers[pe.partition] = pe.prod
+				producers[pe.partition] = pe.prod
+				newProducers = append(newProducers, pe.prod)
 			}
 		}
 	}
 
 	if err != nil {
 		// Since there were some failures, cleanup all the partitions that succeeded in creating the producers
-		for _, producer := range p.producers {
+		for _, producer := range newProducers {
 			if producer != nil {
 				producer.Close()
 			}
@@ -277,8 +273,7 @@ func (p *producer) internalCreatePartitionsProducers() error {
 	} else {
 		p.metrics.ProducersPartitions.Add(float64(partitionsToAdd))
 	}
-	atomic.StorePointer(&p.producersPtr, unsafe.Pointer(&p.producers))
-	atomic.StoreUint32(&p.numPartitions, uint32(len(p.producers)))
+	p.producers.Store(producers)
 	return nil
 }
 
@@ -287,14 +282,11 @@ func (p *producer) Topic() string {
 }
 
 func (p *producer) Name() string {
-	p.RLock()
-	defer p.RUnlock()
-
-	return p.producers[0].Name()
+	return p.getProducer(0).Name()
 }
 
 func (p *producer) NumPartitions() uint32 {
-	return atomic.LoadUint32(&p.numPartitions)
+	return uint32(len(p.getProducers()))
 }
 
 func (p *producer) Send(ctx context.Context, msg *ProducerMessage) (MessageID, error) {
@@ -306,11 +298,11 @@ func (p *producer) SendAsync(ctx context.Context, msg *ProducerMessage,
 	p.getPartition(msg).SendAsync(ctx, msg, callback)
 }
 
-func (p *producer) getPartition(msg *ProducerMessage) Producer {
-	// Since partitions can only increase, it's ok if the producers list
-	// is updated in between. The numPartition is updated only after the list.
-	partition := p.messageRouter(msg, p)
-	producers := *(*[]Producer)(atomic.LoadPointer(&p.producersPtr))
+func (p *producer) getProducer(partition int) Producer {
+	producers := p.getProducers()
+	if len(producers) == 0 {
+		panic("producer has not been initialized properly")
+	}
 	if partition >= len(producers) {
 		// We read the old producers list while the count was already
 		// updated
@@ -319,12 +311,23 @@ func (p *producer) getPartition(msg *ProducerMessage) Producer {
 	return producers[partition]
 }
 
-func (p *producer) LastSequenceID() int64 {
-	p.RLock()
-	defer p.RUnlock()
+func (p *producer) getProducers() []Producer {
+	if producers := p.producers.Load(); producers != nil {
+		return producers.([]Producer)
+	}
+	return []Producer{}
+}
+
+func (p *producer) getPartition(msg *ProducerMessage) Producer {
+	// Since partitions can only increase, it's ok if the producers list
+	// is updated in between; the stored slice is only swapped after the new producers are created.
+	partition := p.messageRouter(msg, p)
+	return p.getProducer(partition)
+}
 
+func (p *producer) LastSequenceID() int64 {
 	var maxSeq int64 = -1
-	for _, pp := range p.producers {
+	for _, pp := range p.getProducers() {
 		s := pp.LastSequenceID()
 		if s > maxSeq {
 			maxSeq = s
@@ -338,10 +341,7 @@ func (p *producer) Flush() error {
 }
 
 func (p *producer) FlushWithCtx(ctx context.Context) error {
-	p.RLock()
-	defer p.RUnlock()
-
-	for _, pp := range p.producers {
+	for _, pp := range p.getProducers() {
 		if err := pp.FlushWithCtx(ctx); err != nil {
 			return err
 		}
@@ -354,14 +354,12 @@ func (p *producer) Close() {
 	p.closeOnce.Do(func() {
 		p.stopDiscovery()
 
-		p.Lock()
-		defer p.Unlock()
-
-		for _, pp := range p.producers {
+		producers := p.getProducers()
+		for _, pp := range producers {
 			pp.Close()
 		}
 		p.client.handlers.Del(p)
-		p.metrics.ProducersPartitions.Sub(float64(len(p.producers)))
+		p.metrics.ProducersPartitions.Sub(float64(len(producers)))
 		p.metrics.ProducersClosed.Inc()
 	})
 }
diff --git a/pulsar/producer_test.go b/pulsar/producer_test.go
index 00876070f7..308059422a 100644
--- a/pulsar/producer_test.go
+++ b/pulsar/producer_test.go
@@ -30,6 +30,9 @@ import (
 	"testing"
 	"time"
 
+	"github.com/apache/pulsar-client-go/pulsaradmin"
+	"github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config"
+	"github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils"
 	"github.com/stretchr/testify/require"
 	"github.com/testcontainers/testcontainers-go"
 	"github.com/testcontainers/testcontainers-go/wait"
@@ -1382,7 +1385,7 @@ func TestProducerWithBackoffPolicy(t *testing.T) {
 	assert.Nil(t, err)
 	defer _producer.Close()
 
-	partitionProducerImp := _producer.(*producer).producers[0].(*partitionProducer)
+	partitionProducerImp := _producer.(*producer).getProducer(0).(*partitionProducer)
 	// 1 s
 	startTime := time.Now()
 	partitionProducerImp.reconnectToBroker(nil)
@@ -2477,7 +2480,7 @@ func TestFailPendingMessageWithClose(t *testing.T) {
 			}
 		})
 	}
-	partitionProducerImp := testProducer.(*producer).producers[0].(*partitionProducer)
+	partitionProducerImp := testProducer.(*producer).getProducer(0).(*partitionProducer)
 	partitionProducerImp.pendingQueue.Put(&pendingItem{
 		buffer: buffersPool.GetBuffer(0),
 	})
@@ -2597,7 +2600,7 @@ func TestDisableReplication(t *testing.T) {
 		writtenBuffers: &writtenBuffers,
 	}
 
-	partitionProducerImp := testProducer.(*producer).producers[0].(*partitionProducer)
+	partitionProducerImp := testProducer.(*producer).getProducer(0).(*partitionProducer)
 	partitionProducerImp.pendingQueue = pqw
 
 	ID, err := testProducer.Send(context.Background(), &ProducerMessage{
@@ -2718,7 +2721,7 @@ func TestSelectConnectionForSameProducer(t *testing.T) {
 	assert.NoError(t, err)
 	defer _producer.Close()
 
-	partitionProducerImp := _producer.(*producer).producers[0].(*partitionProducer)
+	partitionProducerImp := _producer.(*producer).getProducer(0).(*partitionProducer)
 	conn := partitionProducerImp._getConn()
 
 	for i := 0; i < 5; i++ {
@@ -2762,7 +2765,7 @@ func TestSendBufferRetainWhenConnectionStuck(t *testing.T) {
 		Topic: topicName,
 	})
 	assert.NoError(t, err)
-	pp := p.(*producer).producers[0].(*partitionProducer)
+	pp := p.(*producer).getProducer(0).(*partitionProducer)
 
 	// Create a mock connection that tracks written buffers
 	conn := &mockConn{
@@ -2898,3 +2901,54 @@ func testSendAsyncCouldTimeoutWhileReconnecting(t *testing.T, isDisableBatching
 	}
 	close(finalErr)
 }
+
+type mockRPCClient struct {
+	internal.RPCClient
+}
+
+func (m *mockRPCClient) RequestOnCnx(_ internal.Connection, _ uint64, _ pb.BaseCommand_Type,
+	_ proto.Message) (*internal.RPCResult, error) {
+	return nil, fmt.Errorf("expected error")
+}
+
+func TestPartitionUpdateFailed(t *testing.T) {
+	topicName := newTopicName()
+
+	admin, err := pulsaradmin.NewClient(&config.Config{
+		WebServiceURL: adminURL,
+	})
+	require.NoError(t, err)
+
+	tn, err := utils.GetTopicName(topicName)
+	require.NoError(t, err)
+	require.NoError(t, admin.Topics().Create(*tn, 1))
+
+	c, err := NewClient(ClientOptions{
+		URL: serviceURL,
+	})
+	require.NoError(t, err)
+	p, err := c.CreateProducer(ProducerOptions{
+		Topic:                           topicName,
+		PartitionsAutoDiscoveryInterval: time.Second * 1,
+	})
+	require.NoError(t, err)
+	_, err = p.Send(context.Background(), &ProducerMessage{
+		Payload: []byte("test"),
+	})
+	require.NoError(t, err)
+	c.(*client).rpcClient = &mockRPCClient{
+		RPCClient: c.(*client).rpcClient,
+	}
+
+	require.NoError(t, admin.Topics().Update(*tn, 2))
+
+	// Assert that a failed partition update won't affect the existing producers
+	for i := 0; i < 5; i++ {
+		_, err = p.Send(context.Background(), &ProducerMessage{
+			Payload: []byte("test"),
+		})
+		require.NoError(t, err)
+
+		time.Sleep(time.Second * 1)
+	}
+}
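
For readers skimming the patch: the core change replaces the mutex-guarded producers slice (and its unsafe.Pointer/numPartitions mirrors) with a single atomic.Value holding an immutable slice that is rebuilt and published with one Store. The following is a minimal, standalone Go sketch of that copy-on-write pattern, not part of the patch; the registry type, its methods, and the string elements are illustrative stand-ins for the client's producer struct and partition producers.

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

// registry is a hypothetical stand-in for the patched producer struct: readers
// load the current slice lock-free from an atomic.Value, while the partition
// discovery path builds a new slice and publishes it with a single Store.
type registry struct {
	items atomic.Value // holds []string
}

func (r *registry) get() []string {
	if v := r.items.Load(); v != nil {
		return v.([]string)
	}
	return nil
}

// grow copies the published entries into a fresh slice, creates the missing
// ones, and only stores the new slice once every creation has succeeded. On
// error the previously published slice stays visible to readers, which mirrors
// what TestPartitionUpdateFailed asserts for producers.
func (r *registry) grow(newSize int, create func(i int) (string, error)) error {
	old := r.get()
	next := make([]string, newSize)
	copy(next, old)
	for i := len(old); i < newSize; i++ {
		item, err := create(i)
		if err != nil {
			return err // readers keep seeing the old slice
		}
		next[i] = item
	}
	r.items.Store(next)
	return nil
}

func main() {
	r := &registry{}
	_ = r.grow(1, func(i int) (string, error) { return fmt.Sprintf("partition-%d", i), nil })
	_ = r.grow(2, func(i int) (string, error) { return "", errors.New("expected error") })
	fmt.Println(r.get()) // prints [partition-0]: the failed expansion changed nothing
}

The point of the sketch is the ordering: the new slice is published only after every element exists, so a failed partition update never disturbs senders that are already routing to the published producers.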