diff --git a/pkg/config/config.go b/pkg/config/config.go index 5bd00ff3..81bec2b7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,8 +2,10 @@ package config import ( "bytes" + "context" "fmt" "os" + "path/filepath" "github.com/BurntSushi/toml" ) @@ -68,6 +70,9 @@ type StaticConfig struct { // Internal: parsed provider configs (not exposed to TOML package) parsedClusterProviderConfigs map[string]ProviderConfig + + // Internal: the config.toml directory, to help resolve relative file paths + configDirPath string } type GroupVersionKind struct { @@ -76,23 +81,48 @@ type GroupVersionKind struct { Kind string `toml:"kind,omitempty"` } -// Read reads the toml file and returns the StaticConfig. -func Read(configPath string) (*StaticConfig, error) { +type ReadConfigOpt func(cfg *StaticConfig) + +func withDirPath(path string) ReadConfigOpt { + return func(cfg *StaticConfig) { + cfg.configDirPath = path + } +} + +// Read reads the toml file and returns the StaticConfig, with any opts applied. +func Read(configPath string, opts ...ReadConfigOpt) (*StaticConfig, error) { configData, err := os.ReadFile(configPath) if err != nil { return nil, err } - return ReadToml(configData) + + // get and save the absolute dir path to the config file, so that other config parsers can use it + absPath, err := filepath.Abs(configPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve absolute path to config file: %w", err) + } + dirPath := filepath.Dir(absPath) + + cfg, err := ReadToml(configData, append(opts, withDirPath(dirPath))...) + if err != nil { + return nil, err + } + + return cfg, nil } -// ReadToml reads the toml data and returns the StaticConfig. -func ReadToml(configData []byte) (*StaticConfig, error) { +// ReadToml reads the toml data and returns the StaticConfig, with any opts applied +func ReadToml(configData []byte, opts ...ReadConfigOpt) (*StaticConfig, error) { config := Default() md, err := toml.NewDecoder(bytes.NewReader(configData)).Decode(config) if err != nil { return nil, err } + for _, opt := range opts { + opt(config) + } + if err := config.parseClusterProviderConfigs(md); err != nil { return nil, err } @@ -111,13 +141,15 @@ func (c *StaticConfig) parseClusterProviderConfigs(md toml.MetaData) error { c.parsedClusterProviderConfigs = make(map[string]ProviderConfig, len(c.ClusterProviderConfigs)) } + ctx := withConfigDirPath(context.Background(), c.configDirPath) + for strategy, primitive := range c.ClusterProviderConfigs { parser, ok := getProviderConfigParser(strategy) if !ok { continue } - providerConfig, err := parser(primitive, md) + providerConfig, err := parser(ctx, primitive, md) if err != nil { return fmt.Errorf("failed to parse config for ClusterProvider '%s': %w", strategy, err) } diff --git a/pkg/config/provider_config.go b/pkg/config/provider_config.go index 23c5fffe..45dd2f8d 100644 --- a/pkg/config/provider_config.go +++ b/pkg/config/provider_config.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "github.com/BurntSushi/toml" @@ -12,7 +13,27 @@ type ProviderConfig interface { Validate() error } -type ProviderConfigParser func(primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) +type ProviderConfigParser func(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) + +type configDirPathKey struct{} + +func withConfigDirPath(ctx context.Context, dirPath string) context.Context { + return context.WithValue(ctx, configDirPathKey{}, dirPath) +} + +func ConfigDirPathFromContext(ctx context.Context) string { + val := ctx.Value(configDirPathKey{}) + + if val == nil { + return "" + } + + if strVal, ok := val.(string); ok { + return strVal + } + + return "" +} var ( providerConfigParsers = make(map[string]ProviderConfigParser) diff --git a/pkg/config/provider_config_test.go b/pkg/config/provider_config_test.go index d933d894..84902da4 100644 --- a/pkg/config/provider_config_test.go +++ b/pkg/config/provider_config_test.go @@ -1,7 +1,9 @@ package config import ( + "context" "errors" + "path/filepath" "testing" "github.com/BurntSushi/toml" @@ -42,7 +44,7 @@ func (p *ProviderConfigForTest) Validate() error { return nil } -func providerConfigForTestParser(primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) { +func providerConfigForTestParser(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) { var providerConfigForTest ProviderConfigForTest if err := md.PrimitiveDecode(primitive, &providerConfigForTest); err != nil { return nil, err @@ -131,7 +133,7 @@ func (s *ProviderConfigSuite) TestReadConfigUnregisteredProviderConfig() { } func (s *ProviderConfigSuite) TestReadConfigParserError() { - RegisterProviderConfig("test", func(primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) { + RegisterProviderConfig("test", func(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) { return nil, errors.New("parser error forced by test") }) invalidConfigPath := s.writeConfig(` @@ -152,6 +154,35 @@ func (s *ProviderConfigSuite) TestReadConfigParserError() { }) } +func (s *ProviderConfigSuite) TestConfigDirPathInContext() { + var capturedDirPath string + RegisterProviderConfig("test", func(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) { + capturedDirPath = ConfigDirPathFromContext(ctx) + var providerConfigForTest ProviderConfigForTest + if err := md.PrimitiveDecode(primitive, &providerConfigForTest); err != nil { + return nil, err + } + return &providerConfigForTest, nil + }) + configPath := s.writeConfig(` + cluster_provider_strategy = "test" + [cluster_provider_configs.test] + bool_prop = true + str_prop = "a string" + int_prop = 42 + `) + + absConfigPath, err := filepath.Abs(configPath) + s.Require().NoError(err, "test error: getting the absConfigPath should not fail") + + _, err = Read(configPath) + s.Run("provides config directory path in context to parser", func() { + s.Require().NoError(err, "Expected no error reading config") + s.NotEmpty(capturedDirPath, "Expected non-empty directory path in context") + s.Equal(filepath.Dir(absConfigPath), capturedDirPath, "Expected directory path to match config file directory") + }) +} + func TestProviderConfig(t *testing.T) { suite.Run(t, new(ProviderConfigSuite)) }