diff --git a/cmd/add_test.go b/cmd/add_test.go index ef5e88f..845107e 100644 --- a/cmd/add_test.go +++ b/cmd/add_test.go @@ -142,7 +142,7 @@ func TestAddCmd_RegistryFails(t *testing.T) { cmdopts.WithRegistryBuilder(&fakeBuilder{err: errors.New("registry error")}), ) require.NoError(t, err) - + cmdObj.SetOut(io.Discard) cmdObj.SetArgs([]string{"server1"}) err = cmdObj.Execute() require.Error(t, err) diff --git a/cmd/config/daemon/validate_test.go b/cmd/config/daemon/validate_test.go index 6b2c6a5..68ad89d 100644 --- a/cmd/config/daemon/validate_test.go +++ b/cmd/config/daemon/validate_test.go @@ -19,7 +19,7 @@ type mockValidateConfigLoader struct { err error } -func (m *mockValidateConfigLoader) Load(path string) (config.Modifier, error) { +func (m *mockValidateConfigLoader) Load(_ string) (config.Modifier, error) { if m.err != nil { return nil, m.err } @@ -159,7 +159,7 @@ func TestValidateCmd_ConfigLoadError(t *testing.T) { // Assertions require.Error(t, err) - require.Contains(t, err.Error(), "failed to load config") + require.EqualError(t, err, "mock config load error") } func TestValidateCmd_MultipleValidationErrors(t *testing.T) { diff --git a/cmd/config/export/export.go b/cmd/config/export/export.go index 1332625..48b35a1 100644 --- a/cmd/config/export/export.go +++ b/cmd/config/export/export.go @@ -101,7 +101,7 @@ func (c *Cmd) run(cmd *cobra.Command, _ []string) error { func (c *Cmd) handleExport() error { cfg, err := c.cfgLoader.Load(flags.ConfigFile) if err != nil { - return fmt.Errorf("failed to load config: %w", err) + return err } rtCtx, err := c.ctxLoader.Load(flags.RuntimeFile) diff --git a/cmd/config/export/export_test.go b/cmd/config/export/export_test.go index a2d6da2..4eb9513 100644 --- a/cmd/config/export/export_test.go +++ b/cmd/config/export/export_test.go @@ -1,6 +1,7 @@ package export import ( + "io" "os" "path/filepath" "strings" @@ -115,6 +116,9 @@ func TestExportCommand_Integration(t *testing.T) { exportCmd, err := NewCmd(&cmd.BaseCmd{}) require.NoError(t, err) + // Mute the terminal output. + exportCmd.SetOut(io.Discard) + // Set command-specific flags require.NoError(t, exportCmd.Flags().Set("context-output", contextOutput)) require.NoError(t, exportCmd.Flags().Set("contract-output", contractOutput)) diff --git a/cmd/daemon.go b/cmd/daemon.go index f5d2831..2558212 100644 --- a/cmd/daemon.go +++ b/cmd/daemon.go @@ -282,7 +282,7 @@ func newDaemonCobraCmd(daemonCmd *DaemonCmd) *cobra.Command { cobraCommand.MarkFlagsMutuallyExclusive("dev", flagAddr) - // Note: Additional CORS validation required to check CORS flags are present alongside --cors-enable. + // NOTE: Additional CORS validation required to check CORS flags are present alongside --cors-enable. cobraCommand.MarkFlagsRequiredTogether(flagCORSEnable, flagCORSOrigin) return cobraCommand @@ -315,10 +315,16 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error { return err } + // Load the new configuration. + cfg, err := c.LoadConfig(c.cfgLoader) + if err != nil { + return fmt.Errorf("%w: %w", config.ErrConfigLoadFailed, err) + } + // Load configuration layers (config file, then flag overrides). - warnings, err := c.loadConfigurationLayers(logger, cmd) + warnings, err := c.loadConfigurationLayers(logger, cmd, cfg) if err != nil { - return fmt.Errorf("failed to load configuration: %w", err) + return err } if c.dev && len(warnings) > 0 { @@ -342,10 +348,14 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error { return err } - // Load runtime servers from config and context. 
- runtimeServers, err := c.loadRuntimeServers() + execCtx, err := c.ctxLoader.Load(flags.RuntimeFile) if err != nil { - return fmt.Errorf("error loading runtime servers: %w", err) + return fmt.Errorf("failed to load runtime context: %w", err) + } + + runtimeServers, err := runtime.AggregateConfigs(cfg, execCtx) + if err != nil { + return fmt.Errorf("failed to aggregate configs: %w", err) } deps, err := daemon.NewDependencies(logger, addr, runtimeServers) @@ -368,16 +378,22 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error { return fmt.Errorf("failed to create mcpd daemon instance: %w", err) } - daemonCtx, daemonCtxCancel := signal.NotifyContext( - context.Background(), - os.Interrupt, - syscall.SIGTERM, syscall.SIGINT, - ) - defer daemonCtxCancel() + // Create signal contexts for shutdown and reload. + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + defer shutdownCancel() + + // Setup signal handling. + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP) + defer signal.Stop(sigChan) + + // Create reload channel for SIGHUP handling. + reloadChan := make(chan struct{}, 1) + defer close(reloadChan) // Ensure channel is closed on function exit runErr := make(chan error, 1) go func() { - if err := d.StartAndManage(daemonCtx); err != nil && !errors.Is(err, context.Canceled) { + if err := d.StartAndManage(shutdownCtx); err != nil && !errors.Is(err, context.Canceled) { runErr <- err } close(runErr) @@ -387,15 +403,36 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error { c.printDevBanner(cmd.OutOrStdout(), logger, addr) } - select { - case <-daemonCtx.Done(): - logger.Info("Shutting down daemon...") - err := <-runErr // Wait for cleanup and deferred logging - logger.Info("Shutdown complete") - return err // Graceful Ctrl+C / SIGTERM - case err := <-runErr: - logger.Error("daemon exited with error", "error", err) - return err // Propagate daemon failure + // Start signal handling in background. + go c.handleSignals(logger, sigChan, reloadChan, shutdownCancel) + + // Start the daemon's main loop which responds to reloads, shutdowns and startup errors. + for { + select { + case <-reloadChan: + logger.Info("Reloading servers...") + if err := c.reloadServers(shutdownCtx, d); err != nil { + logger.Error("Failed to reload servers, exiting to prevent inconsistent state", "error", err) + // Signal shutdown to exit cleanly. + shutdownCancel() + return fmt.Errorf("configuration reload failed with unrecoverable error: %w", err) + } + + logger.Info("Configuration reloaded successfully") + case <-shutdownCtx.Done(): + logger.Info("Shutting down daemon...") + err := <-runErr // Wait for cleanup and deferred logging + logger.Info("Shutdown complete") + + return err // Graceful shutdown + case err := <-runErr: + if err != nil { + logger.Error("daemon exited with error", "error", err) + return err // Propagate daemon failure + } + + return nil + } } } @@ -428,19 +465,17 @@ func formatValue(value any) string { // It follows the precedence order: flags > config file > defaults. // CLI flags override config file values when explicitly set. // Returns warnings for each flag override and any error encountered. 
-func (c *DaemonCmd) loadConfigurationLayers(logger hclog.Logger, cmd *cobra.Command) ([]string, error) { - cfgModifier, err := c.cfgLoader.Load(flags.ConfigFile) - if err != nil { - return nil, err - } - - cfg, ok := cfgModifier.(*config.Config) - if !ok { - return nil, fmt.Errorf("config file contains invalid configuration structure") +func (c *DaemonCmd) loadConfigurationLayers( + logger hclog.Logger, + cmd *cobra.Command, + cfg *config.Config, +) ([]string, error) { + if cfg == nil { + return nil, fmt.Errorf("config data not present, cannot apply configuration layers") } + // No daemon config section - flags and defaults will be used. if cfg.Daemon == nil { - // No daemon config section - flags and defaults will be used. return nil, nil } @@ -728,24 +763,62 @@ func (c *DaemonCmd) loadConfigCORS(cors *config.CORSConfigSection, logger hclog. return warnings } -// loadRuntimeServers loads the configuration and aggregates it with runtime context to produce runtime servers. -func (c *DaemonCmd) loadRuntimeServers() ([]runtime.Server, error) { - cfgModifier, err := c.cfgLoader.Load(flags.ConfigFile) +// handleSignals processes OS signals for daemon lifecycle management. +// This function is intended to be called in a dedicated goroutine. +// +// SIGHUP signals trigger configuration reloads via reloadChan. +// Termination signals (SIGTERM, SIGINT, os.Interrupt) trigger graceful shutdown via shutdownCancel. +// The function runs until a shutdown signal is received or sigChan is closed. +// Non-blocking sends to reloadChan prevent duplicate reload requests. +func (c *DaemonCmd) handleSignals( + logger hclog.Logger, + sigChan <-chan os.Signal, + reloadChan chan<- struct{}, + shutdownCancel context.CancelFunc, +) { + for sig := range sigChan { + switch sig { + case syscall.SIGHUP: + logger.Info("Received SIGHUP, triggering config reload") + select { + case reloadChan <- struct{}{}: + // Reload signal sent. + default: + // Reload already pending, skip. + logger.Warn("Config reload already in progress, skipping") + } + case os.Interrupt, syscall.SIGTERM, syscall.SIGINT: + logger.Info("Received shutdown signal", "signal", sig) + shutdownCancel() + return + } + } +} + +// reloadServers reloads server configuration from config files. +// This method only reloads runtime servers; daemon config changes require a restart. +func (c *DaemonCmd) reloadServers(ctx context.Context, d *daemon.Daemon) error { + cfg, err := c.LoadConfig(c.cfgLoader) if err != nil { - return nil, fmt.Errorf("failed to load config: %w", err) + return fmt.Errorf("%w: %w", config.ErrConfigLoadFailed, err) } execCtx, err := c.ctxLoader.Load(flags.RuntimeFile) if err != nil { - return nil, fmt.Errorf("failed to load runtime context: %w", err) + return fmt.Errorf("failed to load runtime context: %w", err) } - servers, err := runtime.AggregateConfigs(cfgModifier, execCtx) + newServers, err := runtime.AggregateConfigs(cfg, execCtx) if err != nil { - return nil, fmt.Errorf("failed to aggregate configs: %w", err) + return fmt.Errorf("failed to aggregate configs: %w", err) + } + + // Reload the servers in the daemon. + if err := d.ReloadServers(ctx, newServers); err != nil { + return fmt.Errorf("failed to reload servers: %w", err) } - return servers, nil + return nil } // validateFlags validates the command flags and their relationships. 
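The new `handleSignals` loop above maps `SIGHUP` to a reload request and termination signals to a graceful shutdown. As a hedged illustration (not part of this change), the sketch below shows one way a reload could be triggered programmatically, mirroring `kill -HUP <pid>` from a shell; the `reload` command name and PID argument are hypothetical.

```go
package main

import (
	"fmt"
	"os"
	"strconv"
	"syscall"
)

func main() {
	// Hypothetical helper: send SIGHUP to a running mcpd daemon to request a
	// configuration reload, equivalent to `kill -HUP <pid>` in a shell.
	if len(os.Args) != 2 {
		fmt.Fprintln(os.Stderr, "usage: reload <mcpd-pid>")
		os.Exit(1)
	}

	pid, err := strconv.Atoi(os.Args[1])
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid pid %q: %v\n", os.Args[1], err)
		os.Exit(1)
	}

	proc, err := os.FindProcess(pid)
	if err != nil {
		fmt.Fprintf(os.Stderr, "find process: %v\n", err)
		os.Exit(1)
	}

	// handleSignals performs a non-blocking send on its reload channel, so
	// repeated SIGHUPs collapse into a single pending reload.
	if err := proc.Signal(syscall.SIGHUP); err != nil {
		fmt.Fprintf(os.Stderr, "signal: %v\n", err)
		os.Exit(1)
	}
}
```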
diff --git a/cmd/daemon_test.go b/cmd/daemon_test.go index 36b3890..bc8a9e7 100644 --- a/cmd/daemon_test.go +++ b/cmd/daemon_test.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" "io" + "os" + "syscall" "testing" "time" @@ -183,11 +185,11 @@ func TestDaemon_NewDaemonCmd_WithOptions(t *testing.T) { if tc.wantErr != "" { require.Error(t, err) - assert.Contains(t, err.Error(), tc.wantErr) - assert.Nil(t, cobraCmd) + require.EqualError(t, err, tc.wantErr) + require.Nil(t, cobraCmd) } else { require.NoError(t, err) - assert.NotNil(t, cobraCmd) + require.NotNil(t, cobraCmd) } }) } @@ -251,7 +253,7 @@ func TestDaemon_DaemonCmd_FlagMutualExclusion(t *testing.T) { err = cobraCmd.Execute() require.Error(t, err) - assert.Contains(t, err.Error(), "if any flags in the group [dev addr] are set none of the others can be") + require.Contains(t, err.Error(), "if any flags in the group [dev addr] are set none of the others can be") } func TestDaemon_DaemonCmd_ValidateFlags(t *testing.T) { @@ -288,7 +290,7 @@ func TestDaemon_DaemonCmd_ValidateFlags(t *testing.T) { maxAge: "invalid-duration", }, }, - expectError: "invalid --cors-max-age duration", + expectError: "invalid --cors-max-age duration: time: invalid duration \"invalid-duration\"", }, { name: "invalid API shutdown timeout", @@ -297,7 +299,7 @@ func TestDaemon_DaemonCmd_ValidateFlags(t *testing.T) { apiShutdown: "not-a-duration", }, }, - expectError: "invalid --timeout-api-shutdown duration", + expectError: "invalid --timeout-api-shutdown duration: time: invalid duration \"not-a-duration\"", }, { name: "invalid MCP init timeout", @@ -306,7 +308,7 @@ func TestDaemon_DaemonCmd_ValidateFlags(t *testing.T) { mcpInit: "invalid", }, }, - expectError: "invalid --timeout-mcp-init duration", + expectError: "invalid --timeout-mcp-init duration: time: invalid duration \"invalid\"", }, { name: "invalid health check timeout", @@ -315,7 +317,7 @@ func TestDaemon_DaemonCmd_ValidateFlags(t *testing.T) { healthCheck: "bad-format", }, }, - expectError: "invalid --timeout-mcp-health duration", + expectError: "invalid --timeout-mcp-health duration: time: invalid duration \"bad-format\"", }, { name: "invalid health check interval", @@ -324,7 +326,7 @@ func TestDaemon_DaemonCmd_ValidateFlags(t *testing.T) { healthCheck: "not-valid", }, }, - expectError: "invalid --interval-mcp-health duration", + expectError: "invalid --interval-mcp-health duration: time: invalid duration \"not-valid\"", }, } @@ -352,7 +354,7 @@ func TestDaemon_DaemonCmd_ValidateFlags(t *testing.T) { if tc.expectError != "" { require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectError) + require.EqualError(t, err, tc.expectError) } else { require.NoError(t, err) } @@ -469,7 +471,7 @@ func TestDaemon_DaemonCmd_BuildAPIOptions(t *testing.T) { maxAge: "invalid", }, }, - expectError: "invalid cors-max-age", + expectError: "invalid cors-max-age: time: invalid duration \"invalid\"", }, { name: "invalid API shutdown timeout", @@ -478,7 +480,7 @@ func TestDaemon_DaemonCmd_BuildAPIOptions(t *testing.T) { apiShutdown: "not-valid", }, }, - expectError: "invalid timeout-api-shutdown", + expectError: "invalid timeout-api-shutdown: time: invalid duration \"not-valid\"", }, } @@ -494,7 +496,7 @@ func TestDaemon_DaemonCmd_BuildAPIOptions(t *testing.T) { if tc.expectError != "" { require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectError) + require.EqualError(t, err, tc.expectError) } else { require.NoError(t, err) if tc.validateResult != nil { @@ -540,7 +542,7 @@ func 
TestDaemon_DaemonCmd_BuildDaemonOptions(t *testing.T) { mcpInit: "invalid", }, }, - expectError: "invalid timeout-mcp-init", + expectError: "invalid timeout-mcp-init: time: invalid duration \"invalid\"", }, { name: "invalid health check timeout", @@ -549,7 +551,7 @@ func TestDaemon_DaemonCmd_BuildDaemonOptions(t *testing.T) { healthCheck: "bad-format", }, }, - expectError: "invalid timeout-mcp-health", + expectError: "invalid timeout-mcp-health: time: invalid duration \"bad-format\"", }, { name: "invalid health check interval", @@ -558,7 +560,7 @@ func TestDaemon_DaemonCmd_BuildDaemonOptions(t *testing.T) { healthCheck: "not-valid", }, }, - expectError: "invalid interval-mcp-health", + expectError: "invalid interval-mcp-health: time: invalid duration \"not-valid\"", }, } @@ -577,7 +579,7 @@ func TestDaemon_DaemonCmd_BuildDaemonOptions(t *testing.T) { if tc.expectError != "" { require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectError) + require.EqualError(t, err, tc.expectError) } else { require.NoError(t, err) if tc.validateResult != nil { @@ -1473,13 +1475,13 @@ func TestDaemon_LoadConfigurationLayers(t *testing.T) { }, { name: "config file error", - configError: fmt.Errorf("failed to load config"), - expectError: "failed to load config", + configError: fmt.Errorf("failed - sad times"), + expectError: "failed - sad times", }, { name: "invalid config structure", configData: nil, // This will cause type assertion to fail - expectError: "config file contains invalid configuration structure", + expectError: "config data not present, cannot apply configuration layers", }, { name: "daemon config without flags - uses config values", @@ -1588,6 +1590,8 @@ func TestDaemon_LoadConfigurationLayers(t *testing.T) { config: daemonFlagConfig{}, } + // Note: Config data will be passed directly to loadConfigurationLayers + // Create command and bind flags to struct fields (using the actual production code) command := newDaemonCobraCmd(daemonCmd) @@ -1598,11 +1602,18 @@ func TestDaemon_LoadConfigurationLayers(t *testing.T) { } logger := hclog.NewNullLogger() - warnings, err := daemonCmd.loadConfigurationLayers(logger, command) - if tc.expectError != "" { + if tc.configError != nil { + _, err := daemonCmd.LoadConfig(daemonCmd.cfgLoader) require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectError) + require.EqualError(t, err, tc.expectError) + return + } + + warnings, err := daemonCmd.loadConfigurationLayers(logger, command, tc.configData) + + if tc.expectError != "" { + require.EqualError(t, err, tc.expectError) return } @@ -1648,6 +1659,7 @@ func TestDaemon_LoadConfigurationLayers_Integration(t *testing.T) { daemonCmd := &DaemonCmd{ cfgLoader: mockLoader, config: daemonFlagConfig{}, + // Note: Config data will be passed directly to loadConfigurationLayers } command := newDaemonCobraCmd(daemonCmd) @@ -1661,7 +1673,7 @@ func TestDaemon_LoadConfigurationLayers_Integration(t *testing.T) { require.NoError(t, err) logger := hclog.NewNullLogger() - warnings, err := daemonCmd.loadConfigurationLayers(logger, command) + warnings, err := daemonCmd.loadConfigurationLayers(logger, command, configData) require.NoError(t, err) @@ -1731,7 +1743,7 @@ func (t testInvalidConfigType) RemoveServer(name string) error { retur func (t testInvalidConfigType) ListServers() []config.ServerEntry { return nil } func (t testInvalidConfigType) SaveConfig() error { return nil } -func (m *testMockConfigLoader) Load(path string) (config.Modifier, error) { +func (m *testMockConfigLoader) Load(_ string) 
(config.Modifier, error) { if m.err != nil { return nil, m.err } @@ -1741,3 +1753,203 @@ func (m *testMockConfigLoader) Load(path string) (config.Modifier, error) { } return m.config, nil } + +func TestDaemon_DaemonCmd_HandleSignals(t *testing.T) { + t.Parallel() + + createDaemonCmd := func(t *testing.T) *DaemonCmd { + t.Helper() + baseCmd := &cmd.BaseCmd{} + mockLoader := &mockConfigLoader{entries: []config.ServerEntry{}} + contextLoader := &configcontext.DefaultLoader{} + daemonCmd, err := newDaemonCmd(baseCmd, mockLoader, contextLoader) + require.NoError(t, err) + return daemonCmd + } + + createLogger := func() hclog.Logger { + return hclog.New(&hclog.LoggerOptions{ + Name: "test", + Level: hclog.Off, + Output: io.Discard, + }) + } + + t.Run("SIGHUP triggers reload", func(t *testing.T) { + t.Parallel() + + daemonCmd := createDaemonCmd(t) + logger := createLogger() + + sigChan := make(chan os.Signal, 1) + reloadChan := make(chan struct{}, 1) + shutdownCancel := func() {} + + // Start handleSignals in goroutine. + go daemonCmd.handleSignals(logger, sigChan, reloadChan, shutdownCancel) + + // Send SIGHUP. + sigChan <- syscall.SIGHUP + + // Verify reload signal received. + select { + case <-reloadChan: + // Expected + case <-time.After(100 * time.Millisecond): + t.Fatal("Expected reload signal not received") + } + + close(sigChan) + }) + + t.Run("duplicate SIGHUP signals are handled gracefully", func(t *testing.T) { + t.Parallel() + + daemonCmd := createDaemonCmd(t) + logger := createLogger() + + sigChan := make(chan os.Signal, 2) + reloadChan := make(chan struct{}) // No buffer - will block second send + shutdownCancel := func() {} + + // Start handleSignals in goroutine. + go daemonCmd.handleSignals(logger, sigChan, reloadChan, shutdownCancel) + + // Send two SIGHUP signals quickly. + sigChan <- syscall.SIGHUP + sigChan <- syscall.SIGHUP + + // Verify first reload signal received. + select { + case <-reloadChan: + // Expected - first signal processed + case <-time.After(100 * time.Millisecond): + t.Fatal("Expected first reload signal not received") + } + + // Second signal should be dropped (non-blocking send). + // We can't directly verify the drop, but the function should not hang. + + close(sigChan) + }) + + t.Run("SIGTERM triggers shutdown", func(t *testing.T) { + t.Parallel() + + daemonCmd := createDaemonCmd(t) + logger := createLogger() + + sigChan := make(chan os.Signal, 1) + reloadChan := make(chan struct{}, 1) + shutdownCalled := false + shutdownCancel := func() { shutdownCalled = true } + + // Start handleSignals in goroutine. + done := make(chan struct{}) + go func() { + daemonCmd.handleSignals(logger, sigChan, reloadChan, shutdownCancel) + close(done) + }() + + // Send SIGTERM. + sigChan <- syscall.SIGTERM + + // Verify function returns and shutdown is called. + select { + case <-done: + assert.True(t, shutdownCalled, "shutdown function should be called") + case <-time.After(100 * time.Millisecond): + t.Fatal("handleSignals should return after shutdown signal") + } + }) + + t.Run("SIGINT triggers shutdown", func(t *testing.T) { + t.Parallel() + + daemonCmd := createDaemonCmd(t) + logger := createLogger() + + sigChan := make(chan os.Signal, 1) + reloadChan := make(chan struct{}, 1) + shutdownCalled := false + shutdownCancel := func() { shutdownCalled = true } + + // Start handleSignals in goroutine. + done := make(chan struct{}) + go func() { + daemonCmd.handleSignals(logger, sigChan, reloadChan, shutdownCancel) + close(done) + }() + + // Send SIGINT. 
+ sigChan <- syscall.SIGINT + + // Verify function returns and shutdown is called. + select { + case <-done: + assert.True(t, shutdownCalled, "shutdown function should be called") + case <-time.After(100 * time.Millisecond): + t.Fatal("handleSignals should return after shutdown signal") + } + }) + + t.Run("os.Interrupt triggers shutdown", func(t *testing.T) { + t.Parallel() + + daemonCmd := createDaemonCmd(t) + logger := createLogger() + + sigChan := make(chan os.Signal, 1) + reloadChan := make(chan struct{}, 1) + shutdownCalled := false + shutdownCancel := func() { shutdownCalled = true } + + // Start handleSignals in goroutine. + done := make(chan struct{}) + go func() { + daemonCmd.handleSignals(logger, sigChan, reloadChan, shutdownCancel) + close(done) + }() + + // Send os.Interrupt. + sigChan <- os.Interrupt + + // Verify function returns and shutdown is called. + select { + case <-done: + assert.True(t, shutdownCalled, "shutdown function should be called") + case <-time.After(100 * time.Millisecond): + t.Fatal("handleSignals should return after shutdown signal") + } + }) + + t.Run("channel closure terminates function", func(t *testing.T) { + t.Parallel() + + daemonCmd := createDaemonCmd(t) + logger := createLogger() + + sigChan := make(chan os.Signal) + reloadChan := make(chan struct{}, 1) + shutdownCalled := false + shutdownCancel := func() { shutdownCalled = true } + + // Start handleSignals in goroutine. + done := make(chan struct{}) + go func() { + daemonCmd.handleSignals(logger, sigChan, reloadChan, shutdownCancel) + close(done) + }() + + // Close signal channel. + close(sigChan) + + // Verify function returns without calling shutdown. + select { + case <-done: + assert.False(t, shutdownCalled, "shutdown should not be called on channel closure") + case <-time.After(100 * time.Millisecond): + t.Fatal("handleSignals should return after channel closure") + } + }) +} diff --git a/docs/configuration.md b/docs/configuration.md index abf5132..bd6b073 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -24,11 +24,12 @@ You can provide this path in multiple ways: [[servers]] name = "fetch" package = "uvx::mcp-server-fetch@2025.4.7" + tools = ["fetch"] [[servers]] name = "time" - package = "uvx::mcp-server-time@0.6.2" - tools = ["get_current_time"] + package = "uvx::mcp-server-time@2025.8.4" + tools = ["get_current_time", "convert_time"] ``` --- @@ -65,6 +66,150 @@ Options: --- +## Hot Reload + +The `mcpd` daemon supports hot-reloading of MCP server configurations without requiring a full restart. This allows you to add, remove, or modify server configurations while keeping the daemon running. + +Hot reload processes both: + +- **Server configuration** (`--config-file`) e.g. `.mcpd.toml` +- **Execution context** (`--runtime-file`) e.g. 
`secrets.dev.toml` + +### SIGHUP Signal + +Send a `SIGHUP` signal to the running daemon process to trigger a configuration reload: + +```bash +# Find the daemon process ID +ps aux | grep mcpd + +# Send reload signal (replace PID with actual process ID) +kill -HUP +``` + +### Reload Behavior + +During a hot reload, the daemon intelligently categorizes changes and responds accordingly: + +| Change Type | Action | Description | +|-----------------------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Unchanged servers | Preserve | Servers with identical configurations keep their existing connections, tools, and health status | +| Removed servers | Stop | Servers no longer in the config file are gracefully shut down | +| New servers | Start | Newly added servers are initialized and connected | +| 'Tools-Only' changes | Update | When only the `tools` change, the daemon updates the allowed tools without restarting the server process | +| Configuration changes | Restart | Servers with other configuration changes (package version, environment variables, arguments, execution context, etc.) are stopped and restarted with new settings | + +### Example: 'Tools-Only' Update + +Consider this server configuration: + +```toml +[[servers]] + name = "github" + package = "uvx::modelcontextprotocol/github-server@1.2.3" + tools = ["create_repository", "get_repository"] +``` + +If you modify only the tools list: + +```toml +[[servers]] + name = "github" + package = "uvx::modelcontextprotocol/github-server@1.2.3" + tools = ["create_repository", "get_repository", "list_repositories"] # Additional tools +``` + +The daemon will: + +1. Detect that only the `tools` array changed +2. Update the allowed tools list in-place +3. Keep the existing server process and connections intact +4. Log a message that tools for a server were updated (including the server name and list of tools) + +### Example: Package Version Update + +If you change the package version: + +```toml +[[servers]] + name = "github" + package = "uvx::modelcontextprotocol/github-server@1.3.0" # Version changed + tools = ["create_repository", "get_repository", "list_repositories"] +``` + +The daemon will: + +1. Detect configuration changes beyond just tools +2. Gracefully stop the existing server +3. Start a new server with the updated configuration +4. Log a message that the server is being restarted (including the server name) + +### Execution Context and Environment Variables + +!!! warning "Environment Variable Visibility" + The `mcpd` process can only see environment variables that existed when it started. + + If you export new environment variables in your shell after starting `mcpd`, you must restart the daemon for those variables to become available for shell expansion. + +When the execution context file is reloaded, shell expansion of environment variables (`${VAR}` syntax) +occurs using the environment available to the running `mcpd` process when it was started. 
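The behaviour described above follows from how a process environment works: the daemon's environment is copied from the parent shell at startup, and expansion resolves `${VAR}` against that copy only. The sketch below is not mcpd's implementation, just a minimal Go illustration of those semantics using the standard library's `os.ExpandEnv`; `NEW_TOKEN` is a hypothetical variable.

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	// The process environment is a snapshot taken from the parent shell when
	// this process started; exports made in the shell afterwards are not
	// visible here. os.ExpandEnv resolves ${VAR} against that snapshot.
	fmt.Println(os.ExpandEnv("--home=${HOME}")) // expands to the value present at startup

	// NEW_TOKEN stands in for a variable exported after startup: it expands
	// to an empty string because this process never received it.
	fmt.Println(os.ExpandEnv("--token=${NEW_TOKEN}"))
}
```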
+ +#### What Works During Hot Reload + +Direct values are applied immediately: + +```toml +[servers.jira] + args = ["--confluence-token=test123", "--confluence-url=http://jira-test.mozilla.ai"] +[servers.mcp-discord.env] + DISCORD_TOKEN = "qwerty123!1one" +``` + +Shell expansion of existing environment variables works: + +```toml +[servers.myserver] + args = ["--home=${HOME}", "--user=${USER}"] # These expand to current values +[servers.myserver.env] + CONFIG_PATH = "${HOME}/.config/myapp" # Expands using mcpd's environment +``` + +#### What Requires an `mcpd` Restart + +New environment variables added to the system after `mcpd` started won't be visible: + +```toml +[servers.myserver] + args = ["--token=${NEW_TOKEN}"] # NEW_TOKEN added after mcpd started +[servers.myserver.env] + API_KEY = "${NEWLY_EXPORTED_VAR}" # Won't expand until mcpd restarts +``` + +### Limitations + + +Hot reload does **NOT** apply to: + +- Daemon-level config settings (timeouts, CORS, etc.) +- New environment variables added to the system + +Both require `mcpd` to be restarted for changes to take effect + +### Error Handling + +The reload process maintains strict consistency - any error causes the daemon to exit: + +- **Configuration errors**: Invalid configuration files or loading failures cause the daemon to exit +- **Validation errors**: Invalid server configurations cause the daemon to exit +- **Server operation failures**: Any failure to start, stop, or restart a server causes the daemon to exit + +This ensures the daemon never runs in an inconsistent or partially-failed state, matching the behavior during initial startup where any server failure prevents the daemon from running. + +!!! warning "Reload Failures" + Unlike some systems that allow partial reloads, `mcpd` exits on any reload error to prevent inconsistent state. You'll need to fix the configuration and restart the daemon. + +--- + ## Configuration Export The `mcpd config export` command generates portable configuration files for deployment across different environments. It creates template variables using the naming pattern `MCPD__{SERVER_NAME}__{VARIABLE_NAME}`. @@ -83,7 +228,7 @@ Environment variables and command-line arguments are both converted to template In most cases, this is intentional, the same configuration value is being used in different ways. The collision results in a single template variable that can be used for both the environment variable and command-line argument. 
-#### Example collision +#### Example Collision ```toml [[servers]] diff --git a/internal/api/servers_test.go b/internal/api/servers_test.go index 2fc2449..51cbad8 100644 --- a/internal/api/servers_test.go +++ b/internal/api/servers_test.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "testing" "github.com/mark3labs/mcp-go/client" @@ -48,6 +49,14 @@ func (m *mockMCPClientAccessor) List() []string { return names } +func (m *mockMCPClientAccessor) UpdateTools(name string, tools []string) error { + if _, ok := m.clients[name]; !ok { + return fmt.Errorf("server '%s' not found", name) + } + m.tools[name] = tools + return nil +} + func (m *mockMCPClientAccessor) Remove(name string) { delete(m.clients, name) delete(m.tools, name) diff --git a/internal/cmd/basecmd.go b/internal/cmd/basecmd.go index 1c3e429..22ba83f 100644 --- a/internal/cmd/basecmd.go +++ b/internal/cmd/basecmd.go @@ -186,7 +186,7 @@ func FormatHandler[T any](w io.Writer, format OutputFormat, p output.Printer[T]) func (c *BaseCmd) LoadConfig(loader config.Loader) (*config.Config, error) { cfgModifier, err := loader.Load(flags.ConfigFile) if err != nil { - return nil, fmt.Errorf("failed to load config: %w", err) + return nil, err } cfg, ok := cfgModifier.(*config.Config) diff --git a/internal/config/config.go b/internal/config/config.go index 97268fa..b363cca 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -32,28 +32,33 @@ func (d *DefaultLoader) Init(path string) error { func (d *DefaultLoader) Load(path string) (Modifier, error) { path = strings.TrimSpace(path) if path == "" { - return nil, fmt.Errorf("path cannot be empty") + return nil, fmt.Errorf("%w: path cannot be empty", ErrConfigLoadFailed) } _, err := os.Stat(path) if err != nil { if os.IsNotExist(err) { - return nil, fmt.Errorf("config file cannot be found, run: 'mcpd init'") + return nil, fmt.Errorf("%w: config file cannot be found, run: 'mcpd init'", ErrConfigLoadFailed) } - return nil, fmt.Errorf("failed to stat config file (%s): %w", path, err) + return nil, fmt.Errorf("%w: failed to stat config file (%s): %w", ErrConfigLoadFailed, path, err) } var cfg *Config _, err = toml.DecodeFile(path, &cfg) if err != nil { - return nil, fmt.Errorf("failed to decode config from file (%s): %w", flags.DefaultConfigFile, err) + return nil, fmt.Errorf( + "%w: failed to decode config from file (%s): %w", + ErrConfigLoadFailed, + flags.DefaultConfigFile, + err, + ) } if cfg == nil { - return nil, fmt.Errorf("config file is empty (%s)", path) + return nil, fmt.Errorf("%w: config file is empty (%s)", ErrConfigLoadFailed, path) } if err := cfg.validate(); err != nil { - return nil, fmt.Errorf("failed to validate existing config (%s): %w", path, err) + return nil, fmt.Errorf("%w: failed to validate existing config (%s): %w", ErrConfigLoadFailed, path, err) } // Update the path that loaded this file to track it. diff --git a/internal/config/daemon_config.go b/internal/config/daemon_config.go index 4a3f75b..2340537 100644 --- a/internal/config/daemon_config.go +++ b/internal/config/daemon_config.go @@ -1260,15 +1260,6 @@ func parseBool(value string) (bool, error) { } return v, nil - - //switch strings.ToLower(strings.TrimSpace(value)) { - //case "true", "t", "1", "yes", "y": - // return true, nil - //case "false", "f", "0", "no", "n": - // return false, nil - //default: - // return false, fmt.Errorf("invalid boolean value: %s", value) - //} } // parseDuration parses a string into a Duration value. 
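The loader changes above wrap every load failure (empty path, missing file, decode error, empty or invalid config) with the `ErrConfigLoadFailed` sentinel via `%w`, defined in the `internal/config/errors.go` hunk below. As a hedged sketch (not part of this change), a caller could branch on that sentinel with `errors.Is`; the `loadOrHint` helper is hypothetical.

```go
package cmd

import (
	"errors"
	"fmt"

	"github.com/mozilla-ai/mcpd/v2/internal/config"
	"github.com/mozilla-ai/mcpd/v2/internal/flags"
)

// loadOrHint is a hypothetical helper showing how callers can detect config
// load failures via the ErrConfigLoadFailed sentinel, which DefaultLoader.Load
// now wraps into every error it returns.
func loadOrHint(loader config.Loader) (config.Modifier, error) {
	cfg, err := loader.Load(flags.ConfigFile)
	if err != nil {
		if errors.Is(err, config.ErrConfigLoadFailed) {
			// Loading (path, decode, or validation) failed: point the user at
			// the config file rather than propagating a bare error.
			return nil, fmt.Errorf("check %q or run 'mcpd init': %w", flags.ConfigFile, err)
		}
		return nil, err
	}

	return cfg, nil
}
```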
diff --git a/internal/config/errors.go b/internal/config/errors.go index a040a7d..e09c0e0 100644 --- a/internal/config/errors.go +++ b/internal/config/errors.go @@ -6,8 +6,9 @@ import ( ) var ( - ErrInvalidValue = errors.New("config value invalid") - ErrInvalidKey = errors.New("config key invalid") + ErrInvalidValue = errors.New("config value invalid") + ErrInvalidKey = errors.New("config key invalid") + ErrConfigLoadFailed = errors.New("failed to load configuration") ) // NewErrInvalidValue returns an error for an invalid configuration value. diff --git a/internal/config/types.go b/internal/config/types.go index 1f1537d..dd0e0eb 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -1,6 +1,7 @@ package config import ( + "slices" "strings" "github.com/mozilla-ai/mcpd/v2/internal/context" @@ -118,9 +119,15 @@ type serverKey struct { Package string // NOTE: without version } -func (e *ServerEntry) PackageVersion() string { +// argEntry represents a parsed command line argument. +type argEntry struct { + key string + value string +} + +func (s *ServerEntry) PackageVersion() string { versionDelim := "@" - pkg := stripPrefix(e.Package) + pkg := stripPrefix(s.Package) if idx := strings.LastIndex(pkg, versionDelim); idx != -1 { return pkg[idx+len(versionDelim):] @@ -128,19 +135,8 @@ func (e *ServerEntry) PackageVersion() string { return pkg } -func (e *ServerEntry) PackageName() string { - return stripPrefix(stripVersion(e.Package)) -} - -// argEntry represents a parsed command line argument. -type argEntry struct { - key string - value string -} - -// hasValue is used to determine if an argEntry is a bool flag or contains a value. -func (e *argEntry) hasValue() bool { - return strings.TrimSpace(e.value) != "" +func (s *ServerEntry) PackageName() string { + return stripPrefix(stripVersion(s.Package)) } func (e *argEntry) String() string { @@ -152,13 +148,102 @@ func (e *argEntry) String() string { // RequiredArguments returns all required CLI arguments, including positional, value-based and boolean flags. // NOTE: The order of these arguments matters, so positional arguments appear first. -func (e *ServerEntry) RequiredArguments() []string { - out := make([]string, 0, len(e.RequiredPositionalArgs)+len(e.RequiredValueArgs)+len(e.RequiredBoolArgs)) +func (s *ServerEntry) RequiredArguments() []string { + out := make([]string, 0, len(s.RequiredPositionalArgs)+len(s.RequiredValueArgs)+len(s.RequiredBoolArgs)) // Add positional args first. - out = append(out, e.RequiredPositionalArgs...) - out = append(out, e.RequiredValueArgs...) - out = append(out, e.RequiredBoolArgs...) + out = append(out, s.RequiredPositionalArgs...) + out = append(out, s.RequiredValueArgs...) + out = append(out, s.RequiredBoolArgs...) return out } + +// Equals compares two ServerEntry instances for equality. +// Returns true if all fields are equal. +// RequiredPositionalArgs order matters (positional), all other slices are order-independent. +func (s *ServerEntry) Equals(other *ServerEntry) bool { + if other == nil { + return false + } + + // Compare basic fields. + if s.Name != other.Name { + return false + } + + if s.Package != other.Package { + return false + } + + // RequiredPositionalArgs order matters since they're positional. + if !slices.Equal(s.RequiredPositionalArgs, other.RequiredPositionalArgs) { + return false + } + + // All other slices are flags, so order doesn't matter. + // NOTE: We are assuming that tools are always already normalized, ready for comparison. 
+ if !equalStringSlicesUnordered(s.Tools, other.Tools) { + return false + } + + if !equalStringSlicesUnordered(s.RequiredEnvVars, other.RequiredEnvVars) { + return false + } + + if !equalStringSlicesUnordered(s.RequiredValueArgs, other.RequiredValueArgs) { + return false + } + + if !equalStringSlicesUnordered(s.RequiredBoolArgs, other.RequiredBoolArgs) { + return false + } + + return true +} + +// EqualExceptTools compares this server with another and returns true if only the Tools field differs. +// All other configuration fields must be identical for this to return true. +func (s *ServerEntry) EqualExceptTools(other *ServerEntry) bool { + if other == nil { + return false + } + + // Create copies with identical Tools to compare everything else. + a := s + b := other + + // Temporarily set tools to be identical for comparison. + bTools := b.Tools + b.Tools = a.Tools + + // If everything else is equal, then only tools differ. + equalIgnoringTools := a.Equals(b) + + // Restore original tools. + b.Tools = bTools + + // Return true only if everything else is equal AND tools actually differ. + // NOTE: We are assuming that tools are always already normalized, ready for comparison. + return equalIgnoringTools && !equalStringSlicesUnordered(s.Tools, other.Tools) +} + +// equalStringSlicesUnordered compares two string slices for equality, ignoring order. +func equalStringSlicesUnordered(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + x := slices.Clone(a) + y := slices.Clone(b) + + slices.Sort(x) + slices.Sort(y) + + return slices.Equal(x, y) +} + +// hasValue is used to determine if an argEntry is a bool flag or contains a value. +func (e *argEntry) hasValue() bool { + return strings.TrimSpace(e.value) != "" +} diff --git a/internal/config/types_test.go b/internal/config/types_test.go new file mode 100644 index 0000000..459e165 --- /dev/null +++ b/internal/config/types_test.go @@ -0,0 +1,225 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestServerEntry_Equal(t *testing.T) { + t.Parallel() + + baseEntry := func() *ServerEntry { + return &ServerEntry{ + Name: "test-server", + Package: "uvx::test-server@1.0.0", + Tools: []string{"tool1", "tool2"}, + RequiredEnvVars: []string{"API_KEY", "SECRET"}, + RequiredPositionalArgs: []string{"pos1", "pos2"}, + RequiredValueArgs: []string{"--arg1", "--arg2"}, + RequiredBoolArgs: []string{"--flag1", "--flag2"}, + } + } + + bse := baseEntry() + + testCases := []struct { + name string + entry1 *ServerEntry + entry2 *ServerEntry + expected bool + }{ + { + name: "identical entries", + entry1: bse, + entry2: bse, + expected: true, + }, + { + name: "identical content different instances", + entry1: baseEntry(), + entry2: baseEntry(), + expected: true, + }, + { + name: "nil comparison", + entry1: baseEntry(), + entry2: nil, + expected: false, + }, + { + name: "different names", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.Name = "different-server" + return srv + }(), + expected: false, + }, + { + name: "different packages", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.Package = "uvx::different-server@1.0.0" + return srv + }(), + expected: false, + }, + { + name: "tools order independent - same content", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.Tools = []string{"tool2", "tool1"} // Different order + return srv + }(), + expected: true, + }, + { + name: "env vars order independent 
- same content", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.RequiredEnvVars = []string{"SECRET", "API_KEY"} // Different order + return srv + }(), + expected: true, + }, + { + name: "positional args order matters", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.RequiredPositionalArgs = []string{"pos2", "pos1"} // Different order + return srv + }(), + expected: false, + }, + { + name: "value args order independent - same content", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.RequiredValueArgs = []string{"--arg2", "--arg1"} // Different order + return srv + }(), + expected: true, + }, + { + name: "bool args order independent - same content", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.RequiredBoolArgs = []string{"--flag2", "--flag1"} // Different order + return srv + }(), + expected: true, + }, + { + name: "different tools content", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.Tools = []string{"tool1", "tool3"} // Different content + return srv + }(), + expected: false, + }, + { + name: "empty slices vs nil slices are equal", + entry1: baseEntry(), + entry2: func() *ServerEntry { + srv := baseEntry() + srv.Tools = nil + return srv + }(), + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.entry1.Equals(tc.entry2) + require.Equal(t, tc.expected, result) + }) + } +} + +func TestEqualStringSlicesUnordered(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + a []string + b []string + expected bool + }{ + { + name: "identical slices", + a: []string{"a", "b", "c"}, + b: []string{"a", "b", "c"}, + expected: true, + }, + { + name: "same elements different order", + a: []string{"a", "b", "c"}, + b: []string{"c", "a", "b"}, + expected: true, + }, + { + name: "different elements", + a: []string{"a", "b", "c"}, + b: []string{"a", "b", "d"}, + expected: false, + }, + { + name: "different lengths", + a: []string{"a", "b"}, + b: []string{"a", "b", "c"}, + expected: false, + }, + { + name: "empty slices", + a: []string{}, + b: []string{}, + expected: true, + }, + { + name: "nil slices", + a: nil, + b: nil, + expected: true, + }, + { + name: "empty vs nil", + a: []string{}, + b: nil, + expected: true, + }, + { + name: "duplicate elements same count", + a: []string{"a", "b", "a"}, + b: []string{"b", "a", "a"}, + expected: true, + }, + { + name: "duplicate elements different count", + a: []string{"a", "b", "a"}, + b: []string{"a", "b", "b"}, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := equalStringSlicesUnordered(tc.a, tc.b) + require.Equal(t, tc.expected, result) + }) + } +} diff --git a/internal/contracts/mcp.go b/internal/contracts/mcp.go index f30c918..72cf5c0 100644 --- a/internal/contracts/mcp.go +++ b/internal/contracts/mcp.go @@ -18,6 +18,12 @@ type MCPHealthMonitor interface { // Update records a health check for a tracked server. Update(name string, status domain.HealthStatus, latency *time.Duration) error + + // Add registers a new server for health tracking. + Add(name string) + + // Remove stops tracking health for a server and removes its health data. + Remove(name string) } // MCPClientAccessor provides a way to interact with MCP servers through a client. 
@@ -33,6 +39,10 @@ type MCPClientAccessor interface { // It returns a boolean to indicate whether the tools were found. Tools(name string) ([]string, bool) + // UpdateTools updates the tools list for an existing server without restarting the client. + // Returns an error if the server is not found. + UpdateTools(name string, tools []string) error + // List returns all known server names. List() []string diff --git a/internal/daemon/api_dependencies_test.go b/internal/daemon/api_dependencies_test.go index e232cc9..5fd0a9a 100644 --- a/internal/daemon/api_dependencies_test.go +++ b/internal/daemon/api_dependencies_test.go @@ -30,6 +30,10 @@ func (m *mockClientManager) List() []string { return nil } +func (m *mockClientManager) UpdateTools(name string, tools []string) error { + return nil +} + func (m *mockClientManager) Remove(name string) { } @@ -47,6 +51,14 @@ func (m *mockHealthTracker) Update(name string, status domain.HealthStatus, late return nil } +func (m *mockHealthTracker) Add(name string) { + // Mock implementation. +} + +func (m *mockHealthTracker) Remove(name string) { + // Mock implementation. +} + func TestDaemon_APIDependencies_Validate(t *testing.T) { t.Parallel() diff --git a/internal/daemon/client_manager.go b/internal/daemon/client_manager.go index da06012..00d8f7a 100644 --- a/internal/daemon/client_manager.go +++ b/internal/daemon/client_manager.go @@ -1,6 +1,7 @@ package daemon import ( + "fmt" "sync" "github.com/mark3labs/mcp-go/client" @@ -73,6 +74,26 @@ func (cm *ClientManager) List() []string { return names } +// UpdateTools updates the tools list for an existing server without restarting the client. +// The server name and tool names are normalized for consistent lookups. +// Returns an error if the server is not found. +// This method is safe for concurrent use. +func (cm *ClientManager) UpdateTools(name string, tools []string) error { + name = filter.NormalizeString(name) + tools = filter.NormalizeSlice(tools) + cm.mu.Lock() + defer cm.mu.Unlock() + + // Check if the server exists. + if _, ok := cm.clients[name]; !ok { + return fmt.Errorf("server '%s' not found", name) + } + + // Update the tools list. + cm.serverTools[name] = tools + return nil +} + // Remove deletes the client and its tools by server name. // The server name is normalized for case-insensitive lookup. // This method is safe for concurrent use. diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 3cd902c..20d3259 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -20,6 +20,7 @@ import ( "github.com/mozilla-ai/mcpd/v2/internal/cmd" "github.com/mozilla-ai/mcpd/v2/internal/contracts" "github.com/mozilla-ai/mcpd/v2/internal/domain" + "github.com/mozilla-ai/mcpd/v2/internal/filter" "github.com/mozilla-ai/mcpd/v2/internal/runtime" ) @@ -237,8 +238,9 @@ func (d *Daemon) startMCPServer(ctx context.Context, server runtime.Server) erro packageNameAndVersion = fmt.Sprintf("%s@%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version) logger.Info(fmt.Sprintf("Initialized: '%s'", packageNameAndVersion)) - // Store the client. + // Store and track the client. 
d.clientManager.Add(server.Name(), stdioClient, server.Tools) + d.healthTracker.Add(server.Name()) logger.Info("Ready!") @@ -437,7 +439,7 @@ func (d *Daemon) closeAllClients() { wg.Add(1) go func() { defer wg.Done() - d.closeClientWithTimeout(name, c, timeout) + _ = d.closeClientWithTimeout(name, c, timeout) // Ignore return value - leaks are acceptable during shutdown }() } @@ -445,7 +447,8 @@ func (d *Daemon) closeAllClients() { } // closeClientWithTimeout closes a single client with a timeout. -func (d *Daemon) closeClientWithTimeout(name string, c client.MCPClient, timeout time.Duration) { +// Returns true if the client closed successfully, false if it timed out. +func (d *Daemon) closeClientWithTimeout(name string, c client.MCPClient, timeout time.Duration) bool { d.logger.Info(fmt.Sprintf("Closing client %s", name)) done := make(chan struct{}) @@ -456,17 +459,186 @@ func (d *Daemon) closeClientWithTimeout(name string, c client.MCPClient, timeout // we still log the error but only for debugging purposes. d.logger.Debug("Closing client", "client", name, "error", err) } - d.logger.Info(fmt.Sprintf("Closed client %s", name)) close(done) }() // Wait for this specific client to close or timeout. - // NOTE: this could leak if we just time out clients, - // but since we're exiting mcpd it isn't an issue. select { case <-done: - // Closed successfully. + d.logger.Info(fmt.Sprintf("Closed client %s", name)) + return true case <-time.After(timeout): - d.logger.Warn(fmt.Sprintf("Timeout (%s) closing client %s", timeout.String(), name)) + d.logger.Warn( + fmt.Sprintf("Timeout (%s) closing client %s - process may still be running", timeout.String(), name), + ) + return false + } +} + +// ReloadServers reloads the daemon's MCP servers based on a new configuration. +// It compares the current servers with the new configuration and: +// - Stops servers that have been removed +// - Starts servers that have been added +// - Updates tools for servers where only tools changed +// - Restarts servers with other configuration changes +// - Preserves servers that remain unchanged (keeping their client connections, tools, and health history intact) +func (d *Daemon) ReloadServers(ctx context.Context, newServers []runtime.Server) error { + d.logger.Info("Starting server reload") + + // Validate all new servers before making any changes. + var validateErrs error + for _, srv := range newServers { + if err := srv.Validate(); err != nil { + srvErr := fmt.Errorf("invalid server configuration '%s': %w", srv.Name(), err) + validateErrs = errors.Join(validateErrs, srvErr) + } + } + if validateErrs != nil { + return fmt.Errorf("server validation failed: %w", validateErrs) + } + + existing := make(map[string]*runtime.Server) + for _, srv := range d.runtimeServers { + normalizedName := filter.NormalizeString(srv.Name()) + srvCopy := srv // Create a copy to get pointer + srv.ServerEntry.Name = normalizedName + existing[normalizedName] = &srvCopy + } + + incoming := make(map[string]*runtime.Server, len(newServers)) + for _, srv := range newServers { + normalizedName := filter.NormalizeString(srv.Name()) + srvCopy := srv // Create a copy to get pointer + srv.ServerEntry.Name = normalizedName + incoming[normalizedName] = &srvCopy + } + + // Categorize changes. + var toRemove []string + var toAdd []*runtime.Server + var toUpdateTools []*runtime.Server + var toRestart []*runtime.Server + var unchangedCount int + + // Find servers to remove (in current but not in new). 
+ for name := range existing { + if _, exists := incoming[name]; !exists { + toRemove = append(toRemove, name) + } + } + + // Find servers to add or modify (in new). + for name, srv := range incoming { + existingSrv, exists := existing[name] + switch { + case !exists: + // New server + toAdd = append(toAdd, srv) + case existingSrv.Equals(srv): + // No changes + unchangedCount++ + case existingSrv.EqualsExceptTools(srv): + // Only tools changed + toUpdateTools = append(toUpdateTools, srv) + default: + // Other configuration changed - requires restart + toRestart = append(toRestart, srv) + } + } + + d.logger.Info("Server configuration changes", + "removed", len(toRemove), + "added", len(toAdd), + "tools_updated", len(toUpdateTools), + "restarted", len(toRestart), + "unchanged", unchangedCount) + + var errs []error + + // Stop removed servers. + for _, name := range toRemove { + if err := d.stopMCPServer(name); err != nil { + d.logger.Error("Failed to stop server", "server", name, "error", err) + errs = append(errs, fmt.Errorf("stop %s: %w", name, err)) + } + } + + // Update tools for servers with tools-only changes. + for _, srv := range toUpdateTools { + if err := d.clientManager.UpdateTools(srv.Name(), srv.Tools); err != nil { + d.logger.Error("Failed to update tools", "server", srv.Name(), "error", err) + errs = append(errs, fmt.Errorf("update-tools %s: %w", srv.Name(), err)) + } else { + d.logger.Info("Updated tools", "server", srv.Name(), "tools", srv.Tools) + } + } + + // Restart servers with configuration changes. + for _, srv := range toRestart { + d.logger.Info("Restarting server due to configuration changes", "server", srv.Name()) + + // Stop the existing server. + if err := d.stopMCPServer(srv.Name()); err != nil { + d.logger.Error("Failed to stop server for restart", "server", srv.Name(), "error", err) + errs = append(errs, fmt.Errorf("restart-stop %s: %w", srv.Name(), err)) + continue + } + + // Start the server with new configuration. + if err := d.startMCPServer(ctx, *srv); err != nil { + d.logger.Error("Failed to start server after restart", "server", srv.Name(), "error", err) + errs = append(errs, fmt.Errorf("restart-start %s: %w", srv.Name(), err)) + } + } + + // Start new servers. + for _, srv := range toAdd { + if err := d.startMCPServer(ctx, *srv); err != nil { + d.logger.Error("Failed to start new server", "server", srv.Name(), "error", err) + errs = append(errs, fmt.Errorf("add %s: %w", srv.Name(), err)) + } + } + + // Update stored runtime servers after reload (even if some operations failed). + d.runtimeServers = newServers + + if len(errs) > 0 { + d.logger.Error("Server reload completed with errors", "error_count", len(errs)) + return errors.Join(append([]error{fmt.Errorf("server reload had %d errors", len(errs))}, errs...)...) } + + d.logger.Info("Server reload completed successfully") + return nil +} + +// stopMCPServer gracefully stops a single MCP server and removes it from tracking. +func (d *Daemon) stopMCPServer(name string) error { + d.logger.Info("Stopping MCP server", "server", name) + + c, ok := d.clientManager.Client(name) + if !ok { + return fmt.Errorf("server '%s' not found", name) + } + + // Always remove from managers to maintain consistency. + d.clientManager.Remove(name) + d.healthTracker.Remove(name) + + // Close the client with timeout. 
+	if closed := d.closeClientWithTimeout(name, c, d.clientShutdownTimeout); !closed {
+		d.logger.Error(
+			"MCP server stop timed out - process may still be running and could be leaked",
+			"server", name,
+			"timeout", d.clientShutdownTimeout,
+		)
+
+		return fmt.Errorf(
+			"server '%s' failed to stop within timeout %v - process may be leaked",
+			name,
+			d.clientShutdownTimeout,
+		)
+	}
+
+	d.logger.Info("MCP server stopped successfully", "server", name)
+	return nil
 }
diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go
index 17b889c..307b91b 100644
--- a/internal/daemon/daemon_test.go
+++ b/internal/daemon/daemon_test.go
@@ -862,3 +862,163 @@ func TestDaemon_CloseClientWithTimeout_Direct(t *testing.T) {
 		})
 	}
 }
+
+// TestDaemon_CloseClientWithTimeout_ReturnValue tests the return value behavior of closeClientWithTimeout
+func TestDaemon_CloseClientWithTimeout_ReturnValue(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name           string
+		clientDelay    time.Duration
+		timeout        time.Duration
+		expectedResult bool
+	}{
+		{
+			name:           "successful close returns true",
+			clientDelay:    50 * time.Millisecond,
+			timeout:        200 * time.Millisecond,
+			expectedResult: true,
+		},
+		{
+			name:           "timeout returns false",
+			clientDelay:    500 * time.Millisecond,
+			timeout:        100 * time.Millisecond,
+			expectedResult: false,
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			logger := hclog.NewNullLogger()
+			servers := []runtime.Server{
+				{
+					ServerEntry:            config.ServerEntry{},
+					ServerExecutionContext: configcontext.ServerExecutionContext{},
+				},
+			}
+			deps, err := NewDependencies(logger, ":8085", servers)
+			require.NoError(t, err)
+			daemon, err := NewDaemon(deps)
+			require.NoError(t, err)
+
+			testClient := newMockMCPClientWithBehavior(tc.clientDelay, nil)
+
+			result := daemon.closeClientWithTimeout("test-client", testClient, tc.timeout)
+			assert.Equal(t, tc.expectedResult, result)
+		})
+	}
+}
+
+// TestDaemon_StopMCPServer tests the stopMCPServer method
+func TestDaemon_StopMCPServer(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name           string
+		serverExists   bool
+		clientDelay    time.Duration
+		timeout        time.Duration
+		expectError    bool
+		expectedErrMsg string
+	}{
+		{
+			name:         "successful stop",
+			serverExists: true,
+			clientDelay:  50 * time.Millisecond,
+			timeout:      200 * time.Millisecond,
+			expectError:  false,
+		},
+		{
+			name:           "server not found",
+			serverExists:   false,
+			expectError:    true,
+			expectedErrMsg: "server 'nonexistent' not found",
+		},
+		{
+			name:           "timeout during stop",
+			serverExists:   true,
+			clientDelay:    500 * time.Millisecond,
+			timeout:        100 * time.Millisecond,
+			expectError:    true,
+			expectedErrMsg: "failed to stop within timeout",
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			customSink := &testLoggerSink{}
+			logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
+				Name:  "test-daemon",
+				Level: hclog.Debug,
+			})
+			logger.RegisterSink(customSink)
+
+			clientManager := NewClientManager()
+			healthTracker := NewHealthTracker([]string{})
+			daemon := &Daemon{
+				logger:                logger,
+				clientManager:         clientManager,
+				healthTracker:         healthTracker,
+				clientShutdownTimeout: tc.timeout,
+			}
+
+			serverName := "test-server"
+			if !tc.serverExists {
+				serverName = "nonexistent"
+			}
+
+			if tc.serverExists {
+				testClient := newMockMCPClientWithBehavior(tc.clientDelay, nil)
+				clientManager.Add("test-server", testClient, []string{"tool1"})
+				healthTracker.Add("test-server")
+			}
+
+			err := daemon.stopMCPServer(serverName)
+
+			if tc.expectError {
+				require.Error(t, err)
+				if tc.expectedErrMsg != "" {
+					assert.Contains(t, err.Error(), tc.expectedErrMsg)
+				}
+			} else {
+				require.NoError(t, err)
+
+				// Verify server was removed from managers
+				_, exists := clientManager.Client("test-server")
+				assert.False(t, exists, "Client should be removed from manager")
+
+				_, err := healthTracker.Status("test-server")
+				assert.Error(t, err, "Server should be removed from health tracker")
+			}
+
+			// Verify appropriate logs
+			if tc.serverExists && !tc.expectError {
+				// Should have success log
+				found := false
+				for _, log := range customSink.messages {
+					if strings.Contains(log.message, "stopped successfully") {
+						found = true
+						break
+					}
+				}
+				assert.True(t, found, "Should have success log")
+			} else if tc.serverExists && tc.expectError {
+				// Should have error log for timeout
+				found := false
+				for _, log := range customSink.messages {
+					if log.level == hclog.Error && strings.Contains(log.message, "timed out") {
+						found = true
+						break
+					}
+				}
+				assert.True(t, found, "Should have timeout error log")
+			}
+		})
+	}
+}
diff --git a/internal/daemon/health_tracker.go b/internal/daemon/health_tracker.go
index cc6a7a8..f8dc83d 100644
--- a/internal/daemon/health_tracker.go
+++ b/internal/daemon/health_tracker.go
@@ -79,3 +79,26 @@ func (h *HealthTracker) Update(name string, status domain.HealthStatus, latency
 
 	return nil
 }
+
+// Add registers a new server for health tracking.
+// If the server is already being tracked, this is a no-op.
+func (h *HealthTracker) Add(name string) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	// Only add if not already tracked.
+	if _, exists := h.statuses[name]; !exists {
+		h.statuses[name] = domain.ServerHealth{
+			Name:   name,
+			Status: domain.HealthStatusUnknown,
+		}
+	}
+}
+
+// Remove stops tracking health for a server and removes its health data.
+func (h *HealthTracker) Remove(name string) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	delete(h.statuses, name)
+}
diff --git a/internal/daemon/health_tracker_test.go b/internal/daemon/health_tracker_test.go
index 2a5fc87..5df20fe 100644
--- a/internal/daemon/health_tracker_test.go
+++ b/internal/daemon/health_tracker_test.go
@@ -15,6 +15,188 @@ import (
 	"github.com/mozilla-ai/mcpd/v2/internal/errors"
 )
 
+func TestHealthTracker_Add(t *testing.T) {
+	t.Parallel()
+
+	t.Run("add new server", func(t *testing.T) {
+		t.Parallel()
+		tracker := NewHealthTracker([]string{"existing"})
+
+		// Add a new server.
+		tracker.Add("new-server")
+
+		// Verify it was added.
+		health, err := tracker.Status("new-server")
+		require.NoError(t, err)
+		require.Equal(t, "new-server", health.Name)
+		require.Equal(t, domain.HealthStatusUnknown, health.Status)
+		require.Nil(t, health.LastChecked)
+		require.Nil(t, health.LastSuccessful)
+	})
+
+	t.Run("add existing server (no-op)", func(t *testing.T) {
+		t.Parallel()
+		tracker := NewHealthTracker([]string{"server1"})
+
+		// Update the server's health.
+		latency := 100 * time.Millisecond
+		err := tracker.Update("server1", domain.HealthStatusOK, &latency)
+		require.NoError(t, err)
+
+		// Get the current state.
+		healthBefore, err := tracker.Status("server1")
+		require.NoError(t, err)
+
+		// Try to add the same server again.
+		tracker.Add("server1")
+
+		// Verify the health data is preserved.
+		healthAfter, err := tracker.Status("server1")
+		require.NoError(t, err)
+		require.Equal(t, healthBefore, healthAfter)
+		require.NotNil(t, healthAfter.LastChecked)
+		require.NotNil(t, healthAfter.LastSuccessful)
+	})
+
+	t.Run("concurrent adds", func(t *testing.T) {
+		t.Parallel()
+		tracker := NewHealthTracker([]string{})
+
+		var wg sync.WaitGroup
+		for i := 0; i < 10; i++ {
+			wg.Add(1)
+			go func(i int) {
+				defer wg.Done()
+				tracker.Add(fmt.Sprintf("server-%d", i))
+			}(i)
+		}
+		wg.Wait()
+
+		// Verify all servers were added.
+		servers := tracker.List()
+		require.Len(t, servers, 10)
+	})
+}
+
+func TestHealthTracker_Remove(t *testing.T) {
+	t.Parallel()
+
+	t.Run("remove existing server", func(t *testing.T) {
+		t.Parallel()
+		tracker := NewHealthTracker([]string{"server1", "server2", "server3"})
+
+		// Remove a server.
+		tracker.Remove("server2")
+
+		// Verify it was removed.
+		_, err := tracker.Status("server2")
+		require.Error(t, err)
+		require.True(t, stdErrors.Is(err, errors.ErrHealthNotTracked))
+
+		// Verify other servers remain.
+		servers := tracker.List()
+		require.Len(t, servers, 2)
+		serverNames := make([]string, len(servers))
+		for i, s := range servers {
+			serverNames[i] = s.Name
+		}
+		require.Contains(t, serverNames, "server1")
+		require.Contains(t, serverNames, "server3")
+		require.NotContains(t, serverNames, "server2")
+	})
+
+	t.Run("remove non-existent server (no-op)", func(t *testing.T) {
+		t.Parallel()
+		tracker := NewHealthTracker([]string{"server1"})
+
+		// Remove a non-existent server.
+		tracker.Remove("non-existent")
+
+		// Verify existing server remains.
+		servers := tracker.List()
+		require.Len(t, servers, 1)
+		require.Equal(t, "server1", servers[0].Name)
+	})
+
+	t.Run("concurrent removes", func(t *testing.T) {
+		t.Parallel()
+		serverNames := make([]string, 20)
+		for i := 0; i < 20; i++ {
+			serverNames[i] = fmt.Sprintf("server-%d", i)
+		}
+		tracker := NewHealthTracker(serverNames)
+
+		var wg sync.WaitGroup
+		// Remove even-numbered servers.
+		for i := 0; i < 20; i += 2 {
+			wg.Add(1)
+			go func(i int) {
+				defer wg.Done()
+				tracker.Remove(fmt.Sprintf("server-%d", i))
+			}(i)
+		}
+		wg.Wait()
+
+		// Verify only odd-numbered servers remain.
+		servers := tracker.List()
+		require.Len(t, servers, 10)
+		for _, s := range servers {
+			var num int
+			_, err := fmt.Sscanf(s.Name, "server-%d", &num)
+			require.NoError(t, err)
+			require.Equal(t, 1, num%2, "Expected only odd-numbered servers")
+		}
+	})
+}
+
+func TestHealthTracker_AddRemoveIntegration(t *testing.T) {
+	t.Parallel()
+
+	tracker := NewHealthTracker([]string{"initial"})
+
+	// Add servers.
+	tracker.Add("server1")
+	tracker.Add("server2")
+
+	// Update health for server1.
+	latency := 50 * time.Millisecond
+	err := tracker.Update("server1", domain.HealthStatusOK, &latency)
+	require.NoError(t, err)
+
+	// Remove initial server.
+	tracker.Remove("initial")
+
+	// Add another server.
+	tracker.Add("server3")
+
+	// Remove server2.
+	tracker.Remove("server2")
+
+	// Verify final state.
+	servers := tracker.List()
+	require.Len(t, servers, 2)
+
+	serverMap := make(map[string]domain.ServerHealth)
+	for _, s := range servers {
+		serverMap[s.Name] = s
+	}
+
+	// Verify server1 preserved its health data.
+	require.Contains(t, serverMap, "server1")
+	require.Equal(t, domain.HealthStatusOK, serverMap["server1"].Status)
+	require.NotNil(t, serverMap["server1"].LastChecked)
+	require.NotNil(t, serverMap["server1"].LastSuccessful)
+
+	// Verify server3 is in unknown state.
+	require.Contains(t, serverMap, "server3")
+	require.Equal(t, domain.HealthStatusUnknown, serverMap["server3"].Status)
+	require.Nil(t, serverMap["server3"].LastChecked)
+
+	// Verify removed servers are gone.
+	require.NotContains(t, serverMap, "initial")
+	require.NotContains(t, serverMap, "server2")
+}
+
 func TestNewHealthTracker(t *testing.T) {
 	t.Parallel()
 
diff --git a/internal/runtime/server.go b/internal/runtime/server.go
index 959af16..e1c2212 100644
--- a/internal/runtime/server.go
+++ b/internal/runtime/server.go
@@ -25,6 +25,42 @@ func (s *Server) Name() string {
 	return s.ServerEntry.Name
 }
 
+// Equals compares two Server instances for complete equality.
+// Returns true if both ServerEntry and ServerExecutionContext are equal.
+func (s *Server) Equals(other *Server) bool {
+	if other == nil {
+		return false
+	}
+
+	// Compare static configuration.
+	if !s.ServerEntry.Equals(&other.ServerEntry) {
+		return false
+	}
+
+	// Compare runtime execution context.
+	if !s.ServerExecutionContext.Equals(other.ServerExecutionContext) {
+		return false
+	}
+
+	return true
+}
+
+// EqualsExceptTools compares this server with another and returns true if only the Tools field differs.
+// All other configuration fields (including execution context) must be identical for this to return true.
+func (s *Server) EqualsExceptTools(other *Server) bool {
+	if other == nil {
+		return false
+	}
+
+	// First check if execution context is identical.
+	if !s.ServerExecutionContext.Equals(other.ServerExecutionContext) {
+		return false
+	}
+
+	// Then check if only tools differ in ServerEntry.
+	return s.ServerEntry.EqualExceptTools(&other.ServerEntry)
+}
+
 // Runtime returns the runtime (e.g. python, node) portion of the package string.
 func (s *Server) Runtime() string {
 	parts := strings.Split(s.Package, "::")
diff --git a/internal/runtime/server_test.go b/internal/runtime/server_test.go
index 9ac9146..a321efb 100644
--- a/internal/runtime/server_test.go
+++ b/internal/runtime/server_test.go
@@ -1362,3 +1362,219 @@ func TestEnvVarsToContract(t *testing.T) {
 		})
 	}
 }
+
+func TestServer_Equal(t *testing.T) {
+	t.Parallel()
+
+	baseServer := func() *Server {
+		return &Server{
+			ServerEntry: config.ServerEntry{
+				Name:                   "test-server",
+				Package:                "uvx::test-server@1.0.0",
+				Tools:                  []string{"tool1", "tool2"},
+				RequiredEnvVars:        []string{"API_KEY", "SECRET"},
+				RequiredPositionalArgs: []string{"pos1", "pos2"},
+				RequiredValueArgs:      []string{"--arg1", "--arg2"},
+				RequiredBoolArgs:       []string{"--flag1", "--flag2"},
+			},
+			ServerExecutionContext: context.ServerExecutionContext{
+				Name: "test-server",
+				Args: []string{"--test=value"},
+				Env:  map[string]string{"TEST": "value"},
+			},
+		}
+	}
+
+	testCases := []struct {
+		name     string
+		server1  *Server
+		server2  *Server
+		expected bool
+	}{
+		{
+			name:     "identical servers",
+			server1:  baseServer(),
+			server2:  baseServer(),
+			expected: true,
+		},
+		{
+			name:     "nil comparison",
+			server1:  baseServer(),
+			server2:  nil,
+			expected: false,
+		},
+		{
+			name:    "different static config - tools",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Tools = []string{"tool1", "tool3"}
+				return srv
+			}(),
+			expected: false,
+		},
+		{
+			name:    "different static config - package",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Package = "uvx::test-server@2.0.0"
+				return srv
+			}(),
+			expected: false,
+		},
+		{
+			name:    "different execution context - args",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Args = []string{"--test=different"}
+				return srv
+			}(),
+			expected: false,
+		},
+		{
+			name:    "different execution context - env",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Env = map[string]string{"TEST": "different"}
+				return srv
+			}(),
+			expected: false,
+		},
+		{
+			name:    "different execution context - name",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.ServerExecutionContext.Name = "different-name"
+				return srv
+			}(),
+			expected: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			result := tc.server1.Equals(tc.server2)
+			require.Equal(t, tc.expected, result)
+		})
+	}
+}
+
+func TestServer_EqualExceptTools(t *testing.T) {
+	t.Parallel()
+
+	baseServer := func() *Server {
+		return &Server{
+			ServerEntry: config.ServerEntry{
+				Name:                   "test-server",
+				Package:                "uvx::test-server@1.0.0",
+				Tools:                  []string{"tool1", "tool2"},
+				RequiredEnvVars:        []string{"API_KEY", "SECRET"},
+				RequiredPositionalArgs: []string{"pos1", "pos2"},
+				RequiredValueArgs:      []string{"--arg1", "--arg2"},
+				RequiredBoolArgs:       []string{"--flag1", "--flag2"},
+			},
+			ServerExecutionContext: context.ServerExecutionContext{
+				Name: "test-server",
+				Args: []string{"--test=value"},
+				Env:  map[string]string{"TEST": "value"},
+			},
+		}
+	}
+
+	testCases := []struct {
+		name     string
+		server1  *Server
+		server2  *Server
+		expected bool
+	}{
+		{
+			name:     "identical servers",
+			server1:  baseServer(),
+			server2:  baseServer(),
+			expected: false, // No change at all
+		},
+		{
+			name:     "nil comparison",
+			server1:  baseServer(),
+			server2:  nil,
+			expected: false,
+		},
+		{
+			name:    "only tools changed",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Tools = []string{"tool1", "tool3"} // Different tools (tool3)
+				return srv
+			}(),
+			expected: true,
+		},
+		{
+			name:    "tools changed with different order but same content",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Tools = []string{"tool2", "tool1"} // Different order, same tools.
+				return srv
+			}(),
+			expected: false, // Same tools, different order = no real change
+		},
+		{
+			name:    "package changed",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Package = "uvx::test-server@2.0.0" // Different package version
+				return srv
+			}(),
+			expected: false, // Package changed = not tools-only
+		},
+		{
+			name:    "tools added",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Tools = []string{"tool1", "tool2", "tool3"} // Added tool3
+				return srv
+			}(),
+			expected: true,
+		},
+		{
+			name:    "execution context args changed",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Tools = []string{"tool1", "tool3"}  // Different tools
+				srv.Args = []string{"--test=different"} // Different args
+				return srv
+			}(),
+			expected: false, // Not only tools differ (args also differ)
+		},
+		{
+			name:    "execution context env changed",
+			server1: baseServer(),
+			server2: func() *Server {
+				srv := baseServer()
+				srv.Tools = []string{"tool1", "tool3"}           // Different tools
+				srv.Env = map[string]string{"TEST": "different"} // Different env
+				return srv
+			}(),
+			expected: false, // Not only tools differ (env also differs)
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			result := tc.server1.EqualsExceptTools(tc.server2)
+			require.Equal(t, tc.expected, result)
+		})
+	}
+}