Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestAddCmd_RegistryFails(t *testing.T) {
cmdopts.WithRegistryBuilder(&fakeBuilder{err: errors.New("registry error")}),
)
require.NoError(t, err)

cmdObj.SetOut(io.Discard)
cmdObj.SetArgs([]string{"server1"})
err = cmdObj.Execute()
require.Error(t, err)
Expand Down
4 changes: 2 additions & 2 deletions cmd/config/daemon/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type mockValidateConfigLoader struct {
err error
}

func (m *mockValidateConfigLoader) Load(path string) (config.Modifier, error) {
func (m *mockValidateConfigLoader) Load(_ string) (config.Modifier, error) {
if m.err != nil {
return nil, m.err
}
Expand Down Expand Up @@ -159,7 +159,7 @@ func TestValidateCmd_ConfigLoadError(t *testing.T) {

// Assertions
require.Error(t, err)
require.Contains(t, err.Error(), "failed to load config")
require.EqualError(t, err, "mock config load error")
}

func TestValidateCmd_MultipleValidationErrors(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/config/export/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (c *Cmd) run(cmd *cobra.Command, _ []string) error {
func (c *Cmd) handleExport() error {
cfg, err := c.cfgLoader.Load(flags.ConfigFile)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
return err
}

rtCtx, err := c.ctxLoader.Load(flags.RuntimeFile)
Expand Down
4 changes: 4 additions & 0 deletions cmd/config/export/export_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package export

import (
"io"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -115,6 +116,9 @@ func TestExportCommand_Integration(t *testing.T) {
exportCmd, err := NewCmd(&cmd.BaseCmd{})
require.NoError(t, err)

// Mute the terminal output.
exportCmd.SetOut(io.Discard)

// Set command-specific flags
require.NoError(t, exportCmd.Flags().Set("context-output", contextOutput))
require.NoError(t, exportCmd.Flags().Set("contract-output", contractOutput))
Expand Down
153 changes: 113 additions & 40 deletions cmd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func newDaemonCobraCmd(daemonCmd *DaemonCmd) *cobra.Command {

cobraCommand.MarkFlagsMutuallyExclusive("dev", flagAddr)

// Note: Additional CORS validation required to check CORS flags are present alongside --cors-enable.
// NOTE: Additional CORS validation required to check CORS flags are present alongside --cors-enable.
cobraCommand.MarkFlagsRequiredTogether(flagCORSEnable, flagCORSOrigin)

return cobraCommand
Expand Down Expand Up @@ -315,10 +315,16 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error {
return err
}

// Load the new configuration.
cfg, err := c.LoadConfig(c.cfgLoader)
if err != nil {
return fmt.Errorf("%w: %w", config.ErrConfigLoadFailed, err)
}

// Load configuration layers (config file, then flag overrides).
warnings, err := c.loadConfigurationLayers(logger, cmd)
warnings, err := c.loadConfigurationLayers(logger, cmd, cfg)
if err != nil {
return fmt.Errorf("failed to load configuration: %w", err)
return err
}

if c.dev && len(warnings) > 0 {
Expand All @@ -342,10 +348,14 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error {
return err
}

// Load runtime servers from config and context.
runtimeServers, err := c.loadRuntimeServers()
execCtx, err := c.ctxLoader.Load(flags.RuntimeFile)
if err != nil {
return fmt.Errorf("error loading runtime servers: %w", err)
return fmt.Errorf("failed to load runtime context: %w", err)
}

runtimeServers, err := runtime.AggregateConfigs(cfg, execCtx)
if err != nil {
return fmt.Errorf("failed to aggregate configs: %w", err)
}

deps, err := daemon.NewDependencies(logger, addr, runtimeServers)
Expand All @@ -368,16 +378,22 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("failed to create mcpd daemon instance: %w", err)
}

daemonCtx, daemonCtxCancel := signal.NotifyContext(
context.Background(),
os.Interrupt,
syscall.SIGTERM, syscall.SIGINT,
)
defer daemonCtxCancel()
// Create signal contexts for shutdown and reload.
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()

// Setup signal handling.
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
defer signal.Stop(sigChan)

// Create reload channel for SIGHUP handling.
reloadChan := make(chan struct{}, 1)
defer close(reloadChan) // Ensure channel is closed on function exit

runErr := make(chan error, 1)
go func() {
if err := d.StartAndManage(daemonCtx); err != nil && !errors.Is(err, context.Canceled) {
if err := d.StartAndManage(shutdownCtx); err != nil && !errors.Is(err, context.Canceled) {
runErr <- err
}
close(runErr)
Expand All @@ -387,15 +403,36 @@ func (c *DaemonCmd) run(cmd *cobra.Command, _ []string) error {
c.printDevBanner(cmd.OutOrStdout(), logger, addr)
}

select {
case <-daemonCtx.Done():
logger.Info("Shutting down daemon...")
err := <-runErr // Wait for cleanup and deferred logging
logger.Info("Shutdown complete")
return err // Graceful Ctrl+C / SIGTERM
case err := <-runErr:
logger.Error("daemon exited with error", "error", err)
return err // Propagate daemon failure
// Start signal handling in background.
go c.handleSignals(logger, sigChan, reloadChan, shutdownCancel)

// Start the daemon's main loop which responds to reloads, shutdowns and startup errors.
for {
select {
case <-reloadChan:
logger.Info("Reloading servers...")
if err := c.reloadServers(shutdownCtx, d); err != nil {
logger.Error("Failed to reload servers, exiting to prevent inconsistent state", "error", err)
// Signal shutdown to exit cleanly.
shutdownCancel()
return fmt.Errorf("configuration reload failed with unrecoverable error: %w", err)
}

logger.Info("Configuration reloaded successfully")
case <-shutdownCtx.Done():
logger.Info("Shutting down daemon...")
err := <-runErr // Wait for cleanup and deferred logging
logger.Info("Shutdown complete")

return err // Graceful shutdown
case err := <-runErr:
if err != nil {
logger.Error("daemon exited with error", "error", err)
return err // Propagate daemon failure
}

return nil
}
}
}

Expand Down Expand Up @@ -428,19 +465,17 @@ func formatValue(value any) string {
// It follows the precedence order: flags > config file > defaults.
// CLI flags override config file values when explicitly set.
// Returns warnings for each flag override and any error encountered.
func (c *DaemonCmd) loadConfigurationLayers(logger hclog.Logger, cmd *cobra.Command) ([]string, error) {
cfgModifier, err := c.cfgLoader.Load(flags.ConfigFile)
if err != nil {
return nil, err
}

cfg, ok := cfgModifier.(*config.Config)
if !ok {
return nil, fmt.Errorf("config file contains invalid configuration structure")
func (c *DaemonCmd) loadConfigurationLayers(
logger hclog.Logger,
cmd *cobra.Command,
cfg *config.Config,
) ([]string, error) {
if cfg == nil {
return nil, fmt.Errorf("config data not present, cannot apply configuration layers")
}

// No daemon config section - flags and defaults will be used.
if cfg.Daemon == nil {
// No daemon config section - flags and defaults will be used.
return nil, nil
}

Expand Down Expand Up @@ -728,24 +763,62 @@ func (c *DaemonCmd) loadConfigCORS(cors *config.CORSConfigSection, logger hclog.
return warnings
}

// loadRuntimeServers loads the configuration and aggregates it with runtime context to produce runtime servers.
func (c *DaemonCmd) loadRuntimeServers() ([]runtime.Server, error) {
cfgModifier, err := c.cfgLoader.Load(flags.ConfigFile)
// handleSignals processes OS signals for daemon lifecycle management.
// This function is intended to be called in a dedicated goroutine.
//
// SIGHUP signals trigger configuration reloads via reloadChan.
// Termination signals (SIGTERM, SIGINT, os.Interrupt) trigger graceful shutdown via shutdownCancel.
// The function runs until a shutdown signal is received or sigChan is closed.
// Non-blocking sends to reloadChan prevent duplicate reload requests.
func (c *DaemonCmd) handleSignals(
logger hclog.Logger,
sigChan <-chan os.Signal,
reloadChan chan<- struct{},
shutdownCancel context.CancelFunc,
) {
for sig := range sigChan {
switch sig {
case syscall.SIGHUP:
logger.Info("Received SIGHUP, triggering config reload")
select {
case reloadChan <- struct{}{}:
// Reload signal sent.
default:
// Reload already pending, skip.
logger.Warn("Config reload already in progress, skipping")
}
case os.Interrupt, syscall.SIGTERM, syscall.SIGINT:
logger.Info("Received shutdown signal", "signal", sig)
shutdownCancel()
return
}
}
}

// reloadServers reloads server configuration from config files.
// This method only reloads runtime servers; daemon config changes require a restart.
func (c *DaemonCmd) reloadServers(ctx context.Context, d *daemon.Daemon) error {
cfg, err := c.LoadConfig(c.cfgLoader)
if err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
return fmt.Errorf("%w: %w", config.ErrConfigLoadFailed, err)
}

execCtx, err := c.ctxLoader.Load(flags.RuntimeFile)
if err != nil {
return nil, fmt.Errorf("failed to load runtime context: %w", err)
return fmt.Errorf("failed to load runtime context: %w", err)
}

servers, err := runtime.AggregateConfigs(cfgModifier, execCtx)
newServers, err := runtime.AggregateConfigs(cfg, execCtx)
if err != nil {
return nil, fmt.Errorf("failed to aggregate configs: %w", err)
return fmt.Errorf("failed to aggregate configs: %w", err)
}

// Reload the servers in the daemon.
if err := d.ReloadServers(ctx, newServers); err != nil {
return fmt.Errorf("failed to reload servers: %w", err)
}

return servers, nil
return nil
}

// validateFlags validates the command flags and their relationships.
Expand Down
Loading