20 changes: 12 additions & 8 deletions internal/command/command_plugin.go
@@ -118,7 +118,9 @@ func (cp *CommandPlugin) Process(ctx context.Context, msg *bus.Message) {
if logger.ServerType(ctxWithMetadata) == cp.commandServerType.String() {
switch msg.Topic {
case bus.ConnectionResetTopic:
cp.processConnectionReset(ctxWithMetadata, msg)
// Run in a separate goroutine so that the command plugin can continue to process data plane responses
// while the connection reset is in progress
go cp.processConnectionReset(ctxWithMetadata, msg)
case bus.ResourceUpdateTopic:
cp.processResourceUpdate(ctxWithMetadata, msg)
case bus.InstanceHealthTopic:
@@ -232,11 +234,19 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me
slog.DebugContext(ctx, "Command plugin received connection reset message")

if newConnection, ok := msg.Data.(grpc.GrpcConnectionInterface); ok {
slog.DebugContext(ctx, "Canceling Subscribe after connection reset")
ctxWithMetadata := cp.config.NewContextWithLabels(ctx)
cp.subscribeMutex.Lock()
defer cp.subscribeMutex.Unlock()

// Update the command service with the new client first
err := cp.commandService.UpdateClient(ctxWithMetadata, newConnection.CommandServiceClient())
if err != nil {
slog.ErrorContext(ctx, "Failed to reset connection", "error", err)
return
}

// Once the command service is updated, we close the old connection
slog.DebugContext(ctx, "Canceling Subscribe after connection reset")
if cp.subscribeCancel != nil {
cp.subscribeCancel()
slog.DebugContext(ctxWithMetadata, "Successfully canceled subscribe after connection reset")
@@ -248,12 +258,6 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me
}

cp.conn = newConnection
err := cp.commandService.UpdateClient(ctx, cp.conn.CommandServiceClient())
if err != nil {
slog.ErrorContext(ctx, "Failed to reset connection", "error", err)
return
}

slog.DebugContext(ctxWithMetadata, "Starting new subscribe after connection reset")
subscribeCtx, cp.subscribeCancel = context.WithCancel(ctxWithMetadata)
go cp.commandService.Subscribe(subscribeCtx)
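Note on the hunk above: a minimal, self-contained sketch of the dispatch pattern it adopts, using a hypothetical message type and topic strings rather than the agent's actual bus package. The point is that a slow handler launched on its own goroutine does not block the loop that keeps handling other topics, such as data plane responses.

package main

import (
	"fmt"
	"sync"
	"time"
)

type message struct{ topic string }

// process mimics a plugin message handler: slow work is dispatched on its own
// goroutine, while every other topic is handled inline.
func process(msg message, wg *sync.WaitGroup) {
	switch msg.topic {
	case "connection-reset":
		wg.Add(1)
		go func() {
			defer wg.Done()
			time.Sleep(100 * time.Millisecond) // stands in for the slow connection reset
			fmt.Println("connection reset finished")
		}()
	default:
		fmt.Println("handled", msg.topic) // not blocked by the reset above
	}
}

func main() {
	var wg sync.WaitGroup
	for _, msg := range []message{
		{topic: "connection-reset"},
		{topic: "data-plane-response"},
		{topic: "instance-health"},
	} {
		process(msg, &wg)
	}
	wg.Wait()
}

The trade-off, visible in the plugin change, is that a concurrent reset now needs the subscribeMutex to serialize access to the subscription state.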
9 changes: 8 additions & 1 deletion internal/command/command_plugin_test.go
@@ -150,7 +150,14 @@ func TestCommandPlugin_Process(t *testing.T) {
Topic: bus.ConnectionResetTopic,
Data: commandPlugin.conn,
})
require.Equal(t, 1, fakeCommandService.UpdateClientCallCount())

// The connection reset handler runs in a separate goroutine, so wait for it to complete
assert.Eventually(
t,
func() bool { return fakeCommandService.UpdateClientCallCount() == 1 },
2*time.Second,
10*time.Millisecond,
)
}

func TestCommandPlugin_monitorSubscribeChannel(t *testing.T) {
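A standalone illustration, assuming testify as the assertion library (as in the test above), of why the immediate require.Equal check is replaced by assert.Eventually: the side effect now happens on a background goroutine, so the test polls for it with a deadline instead of asserting right away. The names below are hypothetical and are not the agent's fake command service.

package command_test

import (
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestEventuallyObservesGoroutineSideEffect(t *testing.T) {
	var calls atomic.Int64

	// Stand-in for the connection reset handler that now runs asynchronously.
	go func() {
		time.Sleep(50 * time.Millisecond)
		calls.Add(1)
	}()

	// Poll until the call count reaches 1, failing if it takes longer than 2s.
	assert.Eventually(
		t,
		func() bool { return calls.Load() == 1 },
		2*time.Second,
		10*time.Millisecond,
	)
}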
69 changes: 55 additions & 14 deletions internal/command/command_service.go
@@ -12,6 +12,7 @@ import (
"log/slog"
"sync"
"sync/atomic"
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -33,6 +34,7 @@ var _ commandService = (*CommandService)(nil)

const (
createConnectionMaxElapsedTime = 0
timeToWaitBetweenChecks = 5 * time.Second
)

type (
@@ -41,8 +43,10 @@ type (
subscribeClient mpi.CommandService_SubscribeClient
agentConfig *config.Config
isConnected *atomic.Bool
connectionResetInProgress *atomic.Bool
subscribeChannel chan *mpi.ManagementPlaneRequest
configApplyRequestQueue map[string][]*mpi.ManagementPlaneRequest // key is the instance ID
requestsInProgress map[string]*mpi.ManagementPlaneRequest // key is the correlation ID
resource *mpi.Resource
subscribeClientMutex sync.Mutex
configApplyRequestQueueMutex sync.Mutex
@@ -55,19 +59,16 @@ func NewCommandService(
agentConfig *config.Config,
subscribeChannel chan *mpi.ManagementPlaneRequest,
) *CommandService {
isConnected := &atomic.Bool{}
isConnected.Store(false)

commandService := &CommandService{
commandServiceClient: commandServiceClient,
agentConfig: agentConfig,
isConnected: isConnected,
subscribeChannel: subscribeChannel,
configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest),
resource: &mpi.Resource{},
return &CommandService{
commandServiceClient: commandServiceClient,
agentConfig: agentConfig,
isConnected: &atomic.Bool{},
connectionResetInProgress: &atomic.Bool{},
subscribeChannel: subscribeChannel,
configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest),
resource: &mpi.Resource{},
requestsInProgress: make(map[string]*mpi.ManagementPlaneRequest),
}

return commandService
}

func (cs *CommandService) IsConnected() bool {
@@ -176,6 +177,11 @@ func (cs *CommandService) SendDataPlaneResponse(ctx context.Context, response *m
return err
}

if response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_OK ||
response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_FAILURE {
delete(cs.requestsInProgress, response.GetMessageMeta().GetCorrelationId())
}

return backoff.Retry(
cs.sendDataPlaneResponseCallback(ctx, response),
backoffHelpers.Context(backOffCtx, cs.agentConfig.Client.Backoff),
@@ -256,6 +262,33 @@ func (cs *CommandService) CreateConnection(
}

func (cs *CommandService) UpdateClient(ctx context.Context, client mpi.CommandServiceClient) error {
cs.connectionResetInProgress.Store(true)
defer cs.connectionResetInProgress.Store(false)

// Wait for any in-progress requests to complete before updating the client
start := time.Now()

for len(cs.requestsInProgress) > 0 {
if time.Since(start) >= cs.agentConfig.Client.Grpc.ConnectionResetTimeout {
slog.WarnContext(
ctx,
"Timeout reached while waiting for in-progress requests to complete",
"number_of_requests_in_progress", len(cs.requestsInProgress),
)

break
}

slog.InfoContext(
ctx,
"Waiting for in-progress requests to complete before updating command service gRPC client",
"max_wait_time", cs.agentConfig.Client.Grpc.ConnectionResetTimeout,
"number_of_requests_in_progress", len(cs.requestsInProgress),
)

time.Sleep(timeToWaitBetweenChecks)
}

cs.subscribeClientMutex.Lock()
cs.commandServiceClient = client
cs.subscribeClientMutex.Unlock()
@@ -363,7 +396,7 @@ func (cs *CommandService) sendResponseForQueuedConfigApplyRequests(
cs.configApplyRequestQueue[instanceID] = cs.configApplyRequestQueue[instanceID][indexOfConfigApplyRequest+1:]
slog.DebugContext(ctx, "Removed config apply requests from queue", "queue", cs.configApplyRequestQueue[instanceID])

if len(cs.configApplyRequestQueue[instanceID]) > 0 {
if len(cs.configApplyRequestQueue[instanceID]) > 0 && !cs.connectionResetInProgress.Load() {
cs.subscribeChannel <- cs.configApplyRequestQueue[instanceID][len(cs.configApplyRequestQueue[instanceID])-1]
}

@@ -404,6 +437,12 @@ func (cs *CommandService) dataPlaneHealthCallback(
//nolint:revive // cognitive complexity is 18
func (cs *CommandService) receiveCallback(ctx context.Context) func() error {
return func() error {
if cs.connectionResetInProgress.Load() {
slog.DebugContext(ctx, "Connection reset in progress, skipping receive from subscribe stream")

return nil
}

cs.subscribeClientMutex.Lock()

if cs.subscribeClient == nil {
@@ -444,6 +483,8 @@ func (cs *CommandService) receiveCallback(ctx context.Context) func() error {
default:
cs.subscribeChannel <- request
}

cs.requestsInProgress[request.GetMessageMeta().GetCorrelationId()] = request
}

return nil
@@ -476,7 +517,7 @@ func (cs *CommandService) queueConfigApplyRequests(ctx context.Context, request

instanceID := request.GetConfigApplyRequest().GetOverview().GetConfigVersion().GetInstanceId()
cs.configApplyRequestQueue[instanceID] = append(cs.configApplyRequestQueue[instanceID], request)
if len(cs.configApplyRequestQueue[instanceID]) == 1 {
if len(cs.configApplyRequestQueue[instanceID]) == 1 && !cs.connectionResetInProgress.Load() {
cs.subscribeChannel <- request
} else {
slog.DebugContext(
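Taken together, the command_service.go changes track in-flight management plane requests by correlation ID: a request is recorded when it is received from the subscribe stream, removed again when a terminal (OK or FAILURE) response is sent, and UpdateClient waits, bounded by the new connection reset timeout and polling every five seconds, for the map to drain before swapping the gRPC client. A simplified, self-contained sketch of that lifecycle follows; the types are hypothetical, not the real mpi or command service API, and the sketch wraps the map in a mutex so it is safe to run on its own.

package main

import (
	"fmt"
	"sync"
	"time"
)

type tracker struct {
	mu       sync.Mutex
	inFlight map[string]struct{} // keyed by correlation ID
}

// received records a request pulled off the subscribe stream.
func (t *tracker) received(correlationID string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.inFlight[correlationID] = struct{}{}
}

// responded drops the request once a terminal (OK or FAILURE) response is sent;
// an in-progress status keeps it tracked.
func (t *tracker) responded(correlationID string, terminal bool) {
	if !terminal {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.inFlight, correlationID)
}

// waitForDrain polls until no requests remain or the timeout elapses, then
// reports whether the drain completed, mirroring the bounded wait in UpdateClient.
func (t *tracker) waitForDrain(timeout, tick time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for {
		t.mu.Lock()
		pending := len(t.inFlight)
		t.mu.Unlock()
		if pending == 0 {
			return true
		}
		if time.Now().After(deadline) {
			return false // the reset proceeds anyway after the timeout
		}
		time.Sleep(tick)
	}
}

func main() {
	t := &tracker{inFlight: make(map[string]struct{})}
	t.received("abc-123")

	go func() {
		time.Sleep(50 * time.Millisecond)
		t.responded("abc-123", true) // terminal response arrives while we wait
	}()

	fmt.Println("drained before timeout:", t.waitForDrain(2*time.Second, 10*time.Millisecond))
}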
7 changes: 7 additions & 0 deletions internal/config/config.go
@@ -623,6 +623,12 @@ func registerClientFlags(fs *flag.FlagSet) {
"File chunk size in bytes.",
)

fs.Duration(
ClientGRPCConnectionResetTimeoutKey,
DefGRPCConnectionResetTimeout,
"Duration to wait for in-progress management plane requests to complete before resetting the gRPC connection.",
)

fs.Uint32(
ClientGRPCMaxFileSizeKey,
DefMaxFileSize,
@@ -1112,6 +1118,7 @@ func resolveClient() *Client {
MaxFileSize: viperInstance.GetUint32(ClientGRPCMaxFileSizeKey),
FileChunkSize: viperInstance.GetUint32(ClientGRPCFileChunkSizeKey),
MaxParallelFileOperations: viperInstance.GetInt(ClientGRPCMaxParallelFileOperationsKey),
ConnectionResetTimeout: viperInstance.GetDuration(ClientGRPCConnectionResetTimeoutKey),
},
Backoff: &BackOff{
InitialInterval: viperInstance.GetDuration(ClientBackoffInitialIntervalKey),
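The new setting registered above is a plain duration flag with a three minute default (DefGRPCConnectionResetTimeout) that resolveClient copies into Client.Grpc.ConnectionResetTimeout. A hypothetical, standard-library-only illustration of the same kind of knob is below; the flag name is an assumption, since the agent derives its real name from ClientGRPCConnectionResetTimeoutKey and its key prefix helpers rather than hard-coding it.

package main

import (
	"flag"
	"fmt"
	"time"
)

func main() {
	fs := flag.NewFlagSet("agent", flag.ContinueOnError)

	// Assumed flag name; mirrors the registered description and the 3m default.
	timeout := fs.Duration(
		"client-grpc-connection-reset-timeout",
		3*time.Minute,
		"Duration to wait for in-progress management plane requests to complete before resetting the gRPC connection.",
	)

	// e.g. an operator shortening the wait to 30 seconds:
	if err := fs.Parse([]string{"--client-grpc-connection-reset-timeout=30s"}); err != nil {
		fmt.Println("parse error:", err)
		return
	}
	fmt.Println("connection reset timeout:", *timeout)
}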
13 changes: 7 additions & 6 deletions internal/config/defaults.go
@@ -60,12 +60,13 @@ const (
DefAuxiliaryCommandTLServerNameKey = ""

// Client GRPC Settings
DefMaxMessageSize = 0 // 0 = unset
DefMaxMessageRecieveSize = 4194304 // default 4 MB
DefMaxMessageSendSize = 4194304 // default 4 MB
DefMaxFileSize uint32 = 1048576 // 1MB
DefFileChunkSize uint32 = 524288 // 0.5MB
DefMaxParallelFileOperations = 5
DefMaxMessageSize = 0 // 0 = unset
DefMaxMessageRecieveSize = 4194304 // default 4 MB
DefMaxMessageSendSize = 4194304 // default 4 MB
DefMaxFileSize uint32 = 1048576 // 1MB
DefFileChunkSize uint32 = 524288 // 0.5MB
DefMaxParallelFileOperations = 5
DefGRPCConnectionResetTimeout = 3 * time.Minute

// Client HTTP Settings
DefHTTPTimeout = 10 * time.Second
1 change: 1 addition & 0 deletions internal/config/flags.go
@@ -41,6 +41,7 @@ var (
ClientGRPCMaxFileSizeKey = pre(ClientRootKey) + "grpc_max_file_size"
ClientGRPCFileChunkSizeKey = pre(ClientRootKey) + "grpc_file_chunk_size"
ClientGRPCMaxParallelFileOperationsKey = pre(ClientRootKey) + "grpc_max_parallel_file_operations"
ClientGRPCConnectionResetTimeoutKey = pre(ClientRootKey) + "grpc_connection_reset_timeout"

ClientBackoffInitialIntervalKey = pre(ClientRootKey) + "backoff_initial_interval"
ClientBackoffMaxIntervalKey = pre(ClientRootKey) + "backoff_max_interval"
13 changes: 7 additions & 6 deletions internal/config/types.go
@@ -96,12 +96,13 @@ type (
KeepAlive *KeepAlive `yaml:"keepalive" mapstructure:"keepalive"`
// if MaxMessageSize is size set then we use that value,
// otherwise MaxMessageRecieveSize and MaxMessageSendSize for individual settings
MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"`
MaxMessageReceiveSize int `yaml:"max_message_receive_size" mapstructure:"max_message_receive_size"`
MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"`
MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"`
FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"`
MaxParallelFileOperations int `yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"`
MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"`
MaxMessageReceiveSize int `yaml:"max_message_receive_size" mapstructure:"max_message_receive_size"`
MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"`
MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"`
FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"`
MaxParallelFileOperations int `yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"`
ConnectionResetTimeout time.Duration `yaml:"connection_reset_timeout" mapstructure:"connection_reset_timeout"`
}

KeepAlive struct {