From cd41a35bb1545f83da8ee29eb6296bfd954243bc Mon Sep 17 00:00:00 2001 From: Donal Hurley Date: Mon, 17 Nov 2025 14:22:27 +0000 Subject: [PATCH] Ensure in-progress requests are processed before resetting the gRPC connection with the management plane --- internal/command/command_plugin.go | 20 ++++--- internal/command/command_plugin_test.go | 9 +++- internal/command/command_service.go | 69 ++++++++++++++++++++----- internal/config/config.go | 7 +++ internal/config/defaults.go | 13 ++--- internal/config/flags.go | 1 + internal/config/types.go | 13 ++--- 7 files changed, 97 insertions(+), 35 deletions(-) diff --git a/internal/command/command_plugin.go b/internal/command/command_plugin.go index 616c48eb8..74f050568 100644 --- a/internal/command/command_plugin.go +++ b/internal/command/command_plugin.go @@ -118,7 +118,9 @@ func (cp *CommandPlugin) Process(ctx context.Context, msg *bus.Message) { if logger.ServerType(ctxWithMetadata) == cp.commandServerType.String() { switch msg.Topic { case bus.ConnectionResetTopic: - cp.processConnectionReset(ctxWithMetadata, msg) + // Running as a separate goroutine so that the command plugin can continue to process data plane responses + // while the connection reset is in progress + go cp.processConnectionReset(ctxWithMetadata, msg) case bus.ResourceUpdateTopic: cp.processResourceUpdate(ctxWithMetadata, msg) case bus.InstanceHealthTopic: @@ -232,11 +234,19 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me slog.DebugContext(ctx, "Command plugin received connection reset message") if newConnection, ok := msg.Data.(grpc.GrpcConnectionInterface); ok { - slog.DebugContext(ctx, "Canceling Subscribe after connection reset") ctxWithMetadata := cp.config.NewContextWithLabels(ctx) cp.subscribeMutex.Lock() defer cp.subscribeMutex.Unlock() + // Update the command service with the new client first + err := cp.commandService.UpdateClient(ctxWithMetadata, newConnection.CommandServiceClient()) + if err != nil { + 
slog.ErrorContext(ctx, "Failed to reset connection", "error", err) + return + } + + // Once the command service is updated, we close the old connection + slog.DebugContext(ctx, "Canceling Subscribe after connection reset") if cp.subscribeCancel != nil { cp.subscribeCancel() slog.DebugContext(ctxWithMetadata, "Successfully canceled subscribe after connection reset") @@ -248,12 +258,6 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me } cp.conn = newConnection - err := cp.commandService.UpdateClient(ctx, cp.conn.CommandServiceClient()) - if err != nil { - slog.ErrorContext(ctx, "Failed to reset connection", "error", err) - return - } - slog.DebugContext(ctxWithMetadata, "Starting new subscribe after connection reset") subscribeCtx, cp.subscribeCancel = context.WithCancel(ctxWithMetadata) go cp.commandService.Subscribe(subscribeCtx) diff --git a/internal/command/command_plugin_test.go b/internal/command/command_plugin_test.go index c51f3c579..a6dbe76ab 100644 --- a/internal/command/command_plugin_test.go +++ b/internal/command/command_plugin_test.go @@ -150,7 +150,14 @@ func TestCommandPlugin_Process(t *testing.T) { Topic: bus.ConnectionResetTopic, Data: commandPlugin.conn, }) - require.Equal(t, 1, fakeCommandService.UpdateClientCallCount()) + + // Separate goroutine is executed so need to wait for it to complete + assert.Eventually( + t, + func() bool { return fakeCommandService.UpdateClientCallCount() == 1 }, + 2*time.Second, + 10*time.Millisecond, + ) } func TestCommandPlugin_monitorSubscribeChannel(t *testing.T) { diff --git a/internal/command/command_service.go b/internal/command/command_service.go index a242b92a2..74a7ecd31 100644 --- a/internal/command/command_service.go +++ b/internal/command/command_service.go @@ -12,6 +12,7 @@ import ( "log/slog" "sync" "sync/atomic" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -33,6 +34,7 @@ var _ commandService = (*CommandService)(nil) const ( 
createConnectionMaxElapsedTime = 0 + timeToWaitBetweenChecks = 5 * time.Second ) type ( @@ -41,8 +43,10 @@ type ( subscribeClient mpi.CommandService_SubscribeClient agentConfig *config.Config isConnected *atomic.Bool + connectionResetInProgress *atomic.Bool subscribeChannel chan *mpi.ManagementPlaneRequest configApplyRequestQueue map[string][]*mpi.ManagementPlaneRequest // key is the instance ID + requestsInProgress map[string]*mpi.ManagementPlaneRequest // key is the correlation ID resource *mpi.Resource subscribeClientMutex sync.Mutex configApplyRequestQueueMutex sync.Mutex @@ -55,19 +59,16 @@ func NewCommandService( agentConfig *config.Config, subscribeChannel chan *mpi.ManagementPlaneRequest, ) *CommandService { - isConnected := &atomic.Bool{} - isConnected.Store(false) - - commandService := &CommandService{ - commandServiceClient: commandServiceClient, - agentConfig: agentConfig, - isConnected: isConnected, - subscribeChannel: subscribeChannel, - configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest), - resource: &mpi.Resource{}, + return &CommandService{ + commandServiceClient: commandServiceClient, + agentConfig: agentConfig, + isConnected: &atomic.Bool{}, + connectionResetInProgress: &atomic.Bool{}, + subscribeChannel: subscribeChannel, + configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest), + resource: &mpi.Resource{}, + requestsInProgress: make(map[string]*mpi.ManagementPlaneRequest), } - - return commandService } func (cs *CommandService) IsConnected() bool { @@ -176,6 +177,11 @@ func (cs *CommandService) SendDataPlaneResponse(ctx context.Context, response *m return err } + if response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_OK || + response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_FAILURE { + delete(cs.requestsInProgress, response.GetMessageMeta().GetCorrelationId()) + } + return backoff.Retry( cs.sendDataPlaneResponseCallback(ctx, response), 
backoffHelpers.Context(backOffCtx, cs.agentConfig.Client.Backoff), @@ -256,6 +262,33 @@ func (cs *CommandService) CreateConnection( } func (cs *CommandService) UpdateClient(ctx context.Context, client mpi.CommandServiceClient) error { + cs.connectionResetInProgress.Store(true) + defer cs.connectionResetInProgress.Store(false) + + // Wait for any in-progress requests to complete before updating the client + start := time.Now() + + for len(cs.requestsInProgress) > 0 { + if time.Since(start) >= cs.agentConfig.Client.Grpc.ConnectionResetTimeout { + slog.WarnContext( + ctx, + "Timeout reached while waiting for in-progress requests to complete", + "number_of_requests_in_progress", len(cs.requestsInProgress), + ) + + break + } + + slog.InfoContext( + ctx, + "Waiting for in-progress requests to complete before updating command service gRPC client", + "max_wait_time", cs.agentConfig.Client.Grpc.ConnectionResetTimeout, + "number_of_requests_in_progress", len(cs.requestsInProgress), + ) + + time.Sleep(timeToWaitBetweenChecks) + } + cs.subscribeClientMutex.Lock() cs.commandServiceClient = client cs.subscribeClientMutex.Unlock() @@ -363,7 +396,7 @@ func (cs *CommandService) sendResponseForQueuedConfigApplyRequests( cs.configApplyRequestQueue[instanceID] = cs.configApplyRequestQueue[instanceID][indexOfConfigApplyRequest+1:] slog.DebugContext(ctx, "Removed config apply requests from queue", "queue", cs.configApplyRequestQueue[instanceID]) - if len(cs.configApplyRequestQueue[instanceID]) > 0 { + if len(cs.configApplyRequestQueue[instanceID]) > 0 && !cs.connectionResetInProgress.Load() { cs.subscribeChannel <- cs.configApplyRequestQueue[instanceID][len(cs.configApplyRequestQueue[instanceID])-1] } @@ -404,6 +437,12 @@ func (cs *CommandService) dataPlaneHealthCallback( //nolint:revive // cognitive complexity is 18 func (cs *CommandService) receiveCallback(ctx context.Context) func() error { return func() error { + if cs.connectionResetInProgress.Load() { + slog.DebugContext(ctx, 
"Connection reset in progress, skipping receive from subscribe stream") + + return nil + } + cs.subscribeClientMutex.Lock() if cs.subscribeClient == nil { @@ -444,6 +483,8 @@ func (cs *CommandService) receiveCallback(ctx context.Context) func() error { default: cs.subscribeChannel <- request } + + cs.requestsInProgress[request.GetMessageMeta().GetCorrelationId()] = request } return nil @@ -476,7 +517,7 @@ func (cs *CommandService) queueConfigApplyRequests(ctx context.Context, request instanceID := request.GetConfigApplyRequest().GetOverview().GetConfigVersion().GetInstanceId() cs.configApplyRequestQueue[instanceID] = append(cs.configApplyRequestQueue[instanceID], request) - if len(cs.configApplyRequestQueue[instanceID]) == 1 { + if len(cs.configApplyRequestQueue[instanceID]) == 1 && !cs.connectionResetInProgress.Load() { cs.subscribeChannel <- request } else { slog.DebugContext( diff --git a/internal/config/config.go b/internal/config/config.go index 75c829e64..3de1f7c9f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -623,6 +623,12 @@ func registerClientFlags(fs *flag.FlagSet) { "File chunk size in bytes.", ) + fs.Duration( + ClientGRPCConnectionResetTimeoutKey, + DefGRPCConnectionResetTimeout, + "Duration to wait for in-progress management plane requests to complete before resetting the gRPC connection.", + ) + fs.Uint32( ClientGRPCMaxFileSizeKey, DefMaxFileSize, @@ -1112,6 +1118,7 @@ func resolveClient() *Client { MaxFileSize: viperInstance.GetUint32(ClientGRPCMaxFileSizeKey), FileChunkSize: viperInstance.GetUint32(ClientGRPCFileChunkSizeKey), MaxParallelFileOperations: viperInstance.GetInt(ClientGRPCMaxParallelFileOperationsKey), + ConnectionResetTimeout: viperInstance.GetDuration(ClientGRPCConnectionResetTimeoutKey), }, Backoff: &BackOff{ InitialInterval: viperInstance.GetDuration(ClientBackoffInitialIntervalKey), diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 0f1e08075..acdaad411 100644 --- 
a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -60,12 +60,13 @@ const ( DefAuxiliaryCommandTLServerNameKey = "" // Client GRPC Settings - DefMaxMessageSize = 0 // 0 = unset - DefMaxMessageRecieveSize = 4194304 // default 4 MB - DefMaxMessageSendSize = 4194304 // default 4 MB - DefMaxFileSize uint32 = 1048576 // 1MB - DefFileChunkSize uint32 = 524288 // 0.5MB - DefMaxParallelFileOperations = 5 + DefMaxMessageSize = 0 // 0 = unset + DefMaxMessageRecieveSize = 4194304 // default 4 MB + DefMaxMessageSendSize = 4194304 // default 4 MB + DefMaxFileSize uint32 = 1048576 // 1MB + DefFileChunkSize uint32 = 524288 // 0.5MB + DefMaxParallelFileOperations = 5 + DefGRPCConnectionResetTimeout = 3 * time.Minute // Client HTTP Settings DefHTTPTimeout = 10 * time.Second diff --git a/internal/config/flags.go b/internal/config/flags.go index d0f664540..ef89e5189 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -41,6 +41,7 @@ var ( ClientGRPCMaxFileSizeKey = pre(ClientRootKey) + "grpc_max_file_size" ClientGRPCFileChunkSizeKey = pre(ClientRootKey) + "grpc_file_chunk_size" ClientGRPCMaxParallelFileOperationsKey = pre(ClientRootKey) + "grpc_max_parallel_file_operations" + ClientGRPCConnectionResetTimeoutKey = pre(ClientRootKey) + "grpc_connection_reset_timeout" ClientBackoffInitialIntervalKey = pre(ClientRootKey) + "backoff_initial_interval" ClientBackoffMaxIntervalKey = pre(ClientRootKey) + "backoff_max_interval" diff --git a/internal/config/types.go b/internal/config/types.go index 72eda1369..59d2f64d2 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -96,12 +96,13 @@ type ( KeepAlive *KeepAlive `yaml:"keepalive" mapstructure:"keepalive"` // if MaxMessageSize is size set then we use that value, // otherwise MaxMessageRecieveSize and MaxMessageSendSize for individual settings - MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"` - MaxMessageReceiveSize int `yaml:"max_message_receive_size" 
mapstructure:"max_message_receive_size"` - MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"` - MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"` - FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"` - MaxParallelFileOperations int `yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"` + MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"` + MaxMessageReceiveSize int `yaml:"max_message_receive_size" mapstructure:"max_message_receive_size"` + MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"` + MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"` + FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"` + MaxParallelFileOperations int `yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"` + ConnectionResetTimeout time.Duration `yaml:"connection_reset_timeout" mapstructure:"connection_reset_timeout"` } KeepAlive struct {