Skip to content

Commit 089b5b5

Browse files
committed
Refactor config handling to improve flag validation and path resolution
1 parent a801d5c commit 089b5b5

File tree

2 files changed

+92
-28
lines changed

2 files changed

+92
-28
lines changed

cmd/keymaster/main.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,24 @@ func applyDefaultFlags(cmd *cobra.Command) {
120120

121121
func getConfigPathFromCli(cmd *cobra.Command) (*string, error) {
122122
// Load optional config file argument from cli
123-
if path, err := cmd.PersistentFlags().GetString("config"); err == nil {
124-
// make sure the user provided file exists, to mitigate unwanted behaivio, like loading unwanted default configs
125-
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
126-
return nil, err
123+
// Only proceed if the user has explicitly set the --config flag.
124+
if cmd.PersistentFlags().Changed("config") {
125+
path, err := cmd.PersistentFlags().GetString("config")
126+
if err != nil {
127+
// This is unlikely if Changed() is true, but good practice.
128+
return nil, fmt.Errorf("could not read --config flag: %w", err)
129+
}
130+
131+
// If the flag is set but the value is empty, do nothing.
132+
if path == "" {
133+
return nil, nil
134+
}
135+
136+
// Make sure the user-provided file exists to avoid unwanted behavior.
137+
if _, err := os.Stat(path); err != nil {
138+
return nil, fmt.Errorf("config file specified via --config flag not found or is not accessible: %w", err)
127139
}
128-
return &path, nil
140+
return &path, nil // Return the valid path
129141
}
130142
return nil, nil
131143
}

internal/config/config.go

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,92 @@
11
package config
22

33
import (
4+
"fmt"
45
"os"
6+
"path/filepath"
7+
"runtime"
58
"strings"
69

710
"github.com/goccy/go-yaml"
811
"github.com/spf13/cobra"
912
"github.com/spf13/viper"
1013
)
1114

12-
func configPath(system bool) (string, error) {
15+
// getConfigPath returns the full path for the configuration file.
16+
func getConfigPath(system bool) (string, error) {
17+
var configDir string
18+
var err error
19+
1320
if system {
14-
// TODO make os aware
15-
return "/etc/keymaster", nil
21+
// System-wide configuration paths
22+
switch runtime.GOOS {
23+
case "windows":
24+
configDir = filepath.Join(os.Getenv("ProgramData"), "Keymaster")
25+
default: // Linux, macOS, etc.
26+
configDir = "/etc/keymaster"
27+
}
1628
} else {
17-
return os.UserConfigDir()
29+
// User-specific configuration paths
30+
configDir, err = os.UserConfigDir()
31+
if err != nil {
32+
return "", fmt.Errorf("could not get user config directory: %w", err)
33+
}
34+
configDir = filepath.Join(configDir, "keymaster")
1835
}
36+
37+
return filepath.Join(configDir, "keymaster.yaml"), nil
1938
}
2039

2140
func LoadConfig[T any](cmd *cobra.Command, defaults map[string]any, additional_config_file_path *string) (T, error) {
2241
var c T
2342
v := viper.New()
2443

25-
// defaults
44+
// 1. Set defaults
2645
for key, value := range defaults {
2746
v.SetDefault(key, value)
2847
}
2948

30-
// files (first file found wins)
49+
// 2. Set up file search paths (new format: keymaster.yaml)
3150
v.SetConfigName("keymaster")
3251
v.SetConfigType("yaml")
52+
53+
// 3. Add explicit config file path if provided via --config flag.
54+
// This has the highest precedence for file-based configuration.
3355
if additional_config_file_path != nil {
34-
v.AddConfigPath(*additional_config_file_path)
56+
v.SetConfigFile(*additional_config_file_path)
3557
}
36-
if home, err := os.UserHomeDir(); err == nil {
37-
v.AddConfigPath(home + "/.config")
58+
59+
// 3. Add standard config locations
60+
if userConfigPath, err := getConfigPath(false); err == nil {
61+
v.AddConfigPath(filepath.Dir(userConfigPath))
3862
}
39-
v.AddConfigPath("/etc/keymaster")
63+
if systemConfigPath, err := getConfigPath(true); err == nil {
64+
v.AddConfigPath(filepath.Dir(systemConfigPath))
65+
}
66+
v.AddConfigPath(".") // Look for keymaster.yaml in current dir
67+
68+
// 5. Read in the primary config file.
69+
if err := v.ReadInConfig(); err != nil {
70+
// It's okay if the file is not found, but other errors are fatal.
71+
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
72+
return c, err
73+
}
74+
}
75+
76+
// 6. For backward compatibility, check for and merge `.keymaster.yaml` in the current directory.
77+
mergeLegacyConfig(v)
4078

41-
// env
79+
// 7. Read from environment variables
4280
v.AutomaticEnv()
4381
v.AllowEmptyEnv(true)
4482
v.SetEnvPrefix("keymaster")
4583
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
4684

4785
// cli
4886
// TODO maybe needs to trigger additional parsing beferohand (most likely nots)
49-
v.BindPFlags(cmd.Flags())
50-
51-
// TODO maybe not needed
52-
// if err := v.ReadInConfig(); err != nil {
53-
// if _, ok := err.(v.ConfigFileNotFoundError); !ok {
54-
// return nil, err
55-
// }
56-
// }
87+
if err := v.BindPFlags(cmd.Flags()); err != nil {
88+
return c, err
89+
}
5790

5891
// parse config
5992
if err := v.Unmarshal(&c); err != nil {
@@ -63,8 +96,24 @@ func LoadConfig[T any](cmd *cobra.Command, defaults map[string]any, additional_c
6396
return c, nil
6497
}
6598

99+
// mergeLegacyConfig checks for a `.keymaster.yaml` file in the current directory
100+
// and merges it into the viper configuration if found. This is for backward compatibility.
101+
func mergeLegacyConfig(v *viper.Viper) {
102+
legacyConfigFile := ".keymaster.yaml"
103+
if _, err := os.Stat(legacyConfigFile); err == nil {
104+
// File exists, let's try to merge it.
105+
v.SetConfigFile(legacyConfigFile)
106+
// MergeInConfig will not error on file not found, but we already checked.
107+
// It will error on a malformed file, which is the desired behavior.
108+
// We can ignore the error for this compatibility layer to avoid breaking startup.
109+
_ = v.MergeInConfig()
110+
// Reset the config file path to avoid side effects.
111+
v.SetConfigFile("")
112+
}
113+
}
114+
66115
func WriteConfigFile[T any](c *T, system bool) error {
67-
path, err := configPath(system)
116+
path, err := getConfigPath(system)
68117
if err != nil {
69118
return err
70119
}
@@ -74,10 +123,13 @@ func WriteConfigFile[T any](c *T, system bool) error {
74123
return err
75124
}
76125

77-
// TODO recursively create directory if not present
126+
// Create directory if it doesn't exist
127+
configDir := filepath.Dir(path)
128+
if err := os.MkdirAll(configDir, 0755); err != nil {
129+
return fmt.Errorf("could not create config directory %s: %w", configDir, err)
130+
}
78131

79-
// TODO review permissions, because database secrets may be saved here in user mode (user restricted read permissions when not in system mode)
80-
err = os.WriteFile(path, data, 0644)
132+
err = os.WriteFile(path, data, 0600) // Use 0600 for security, as it may contain secrets
81133
if err != nil {
82134
return err
83135
}

0 commit comments

Comments
 (0)