diff --git a/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver.go b/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver.go
index 93a9e37..773200d 100644
--- a/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver.go
+++ b/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver.go
@@ -4,36 +4,36 @@
 import (
 	"context"
 	"encoding/json"
 	"log"
+	"sync" // <--- added for mutex
+	pb "github.com/cybertec-postgresql/pgwatch3/rpc/proto" // <--- aliased to pb so existing pb.* references compile
 	"github.com/destrex271/pgwatch3_rpc_server/sinks"
-	"github.com/destrex271/pgwatch3_rpc_server/sinks/pb"
 	"github.com/segmentio/kafka-go"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
 type KafkaProdReceiver struct {
-	conn_regisrty map[string]*kafka.Conn
-	uri           string
-	auto_add      bool
+	mu           sync.RWMutex           // <--- added mutex to protect connRegistry
+	connRegistry map[string]*kafka.Conn // <--- renamed from conn_regisrty
+	uri          string
+	auto_add     bool
 	sinks.SyncMetricHandler
 }
 
-// Handle Sync Metric Instructions
 func (r *KafkaProdReceiver) HandleSyncMetric() {
 	for {
 		req, ok := r.GetSyncChannelContent()
 		if !ok {
-			// channel has been closed
-			return
+			return // channel closed
 		}
 
 		var err error
 		switch req.Operation {
-		case pb.SyncOp_AddOp:
-			err = r.AddTopicIfNotExists(req.GetDBName())
-		case pb.SyncOp_DeleteOp:
-			err = r.CloseConnectionForDB(req.GetDBName())
+		case pb.SyncOperation_ADD:
+			err = r.AddTopicIfNotExists(req.GetDBName()) // <--- mutex used inside AddTopicIfNotExists
+		case pb.SyncOperation_DELETE:
+			err = r.CloseConnectionForDB(req.GetDBName()) // <--- mutex used inside CloseConnectionForDB
 		}
 
 		if err != nil {
@@ -59,20 +59,21 @@ func NewKafkaProducer(host string, topics []string, partitions []int, auto_add b
 		connRegistry[topic] = conn
 	}
 
 	kpr = &KafkaProdReceiver{
-		conn_regisrty:     connRegistry,
+		connRegistry:      connRegistry,
 		uri:               host,
 		SyncMetricHandler: sinks.NewSyncMetricHandler(1024),
 		auto_add:          auto_add,
 	}
-	// Start sync Handler routine
-	go kpr.HandleSyncMetric()
+	go kpr.HandleSyncMetric() // start background sync goroutine
 
 	return kpr, nil
 }
 
 func (r *KafkaProdReceiver) AddTopicIfNotExists(dbName string) error {
-	_, ok := r.conn_regisrty[dbName]
-	if ok {
+	r.mu.Lock()         // <--- lock for write
+	defer r.mu.Unlock() // <--- released automatically at function exit
+
+	if _, ok := r.connRegistry[dbName]; ok {
 		return nil
 	}
 
@@ -81,13 +82,19 @@ func (r *KafkaProdReceiver) AddTopicIfNotExists(dbName string) error {
 		return err
 	}
 
-	r.conn_regisrty[dbName] = new_conn
+	r.connRegistry[dbName] = new_conn
 	log.Println("[INFO]: Added Database " + dbName + " to sink")
 	return nil
 }
 
 func (r *KafkaProdReceiver) CloseConnectionForDB(dbName string) error {
-	conn, ok := r.conn_regisrty[dbName]
+	r.mu.Lock() // <--- lock for write
+	conn, ok := r.connRegistry[dbName]
+	if ok {
+		delete(r.connRegistry, dbName)
+	}
+	r.mu.Unlock() // <--- release before closing the connection below
+
 	if !ok {
 		return nil
 	}
 
@@ -97,41 +104,46 @@ func (r *KafkaProdReceiver) CloseConnectionForDB(dbName string) error {
 		return err
 	}
 
-	delete(r.conn_regisrty, dbName)
 	log.Println("[INFO]: Deleted Database " + dbName + " from sink")
 	return nil
 }
 
 func (r *KafkaProdReceiver) UpdateMeasurements(ctx context.Context, msg *pb.MeasurementEnvelope) (*pb.Reply, error) {
-	// Get connection for database topic
 	DBName := msg.GetDBName()
-	conn, ok := r.conn_regisrty[DBName]
+
+	// Look up the connection under the read lock: Go maps are unsafe for concurrent access
+	conn, ok := func() (*kafka.Conn, bool) {
+		r.mu.RLock()
+		defer r.mu.RUnlock()
+		conn, ok := r.connRegistry[DBName]
+		return conn, ok
+	}()
+
 	if !ok {
 		log.Println("[WARNING]: Connection does not exist for database " + DBName)
 		if r.auto_add {
-			log.Println("[INFO]: Adding database " + DBName + " since Auto Add is enabled. You can disable it by restarting the sink with autoadd option as false")
-			err := r.AddTopicIfNotExists(DBName)
+			log.Println("[INFO]: Adding database " + DBName + " since Auto Add is enabled")
+			err := r.AddTopicIfNotExists(DBName) // takes the write lock internally
 			if err != nil {
 				log.Println("[ERROR]: Unable to create new connection")
 				return nil, err
 			}
-			conn = r.conn_regisrty[DBName]
+			// re-read under the read lock now that the entry exists
+			r.mu.RLock()
+			conn = r.connRegistry[DBName]
+			r.mu.RUnlock()
 		} else {
-			return nil, status.Error(codes.FailedPrecondition, "auto add not enabled. please restart the sink with autoadd=true")
+			return nil, status.Error(codes.FailedPrecondition, "auto add not enabled")
 		}
 	}
 
-	// Convert MeasurementEnvelope struct to json and write it as message in kafka
 	json_data, err := json.Marshal(msg)
 	if err != nil {
 		log.Println("Unable to convert measurements data to json")
 		return nil, status.Error(codes.InvalidArgument, err.Error())
 	}
 
-	_, err = conn.WriteMessages(
-		kafka.Message{Value: json_data},
-	)
-
+	_, err = conn.WriteMessages(kafka.Message{Value: json_data})
 	if err != nil {
 		log.Println("Failed to write messages!")
 		return nil, err
@@ -139,4 +151,4 @@ func (r *KafkaProdReceiver) UpdateMeasurements(ctx context.Context, msg *pb.Meas
 
 	log.Println("[INFO]: Measurements Written to topic - ", DBName)
 	return &pb.Reply{}, nil
-}
\ No newline at end of file
+}
diff --git a/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver_test.go b/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver_test.go
index b553aa8..a22b5fc 100644
--- a/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver_test.go
+++ b/rpc/cmd/kafka_prod_receiver/kafka_prod_receiver_test.go
@@ -87,14 +87,14 @@ func TestKafka_SyncMetricHandler(t *testing.T) {
 	assert.NoError(t, err)
 	time.Sleep(time.Second) // give some time handler
 
-	_, exists := kpr.conn_regisrty[req.GetDBName()]
+	_, exists := kpr.connRegistry[req.GetDBName()]
 	assert.True(t, exists)
 
-	req.Operation = pb.SyncOp_DeleteOp
+	req.Operation = pb.SyncOperation_DELETE
 	_, err = kpr.SyncMetric(ctx, req)
 	assert.NoError(t, err)
 	time.Sleep(time.Second) // give some time handler
 
-	_, exists = kpr.conn_regisrty[req.GetDBName()]
+	_, exists = kpr.connRegistry[req.GetDBName()]
 	assert.False(t, exists)
-}
\ No newline at end of file
+}