Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package config

import (
"bytes"
"context"
"fmt"
"os"
"path/filepath"

"github.com/BurntSushi/toml"
)
Expand Down Expand Up @@ -68,6 +70,9 @@ type StaticConfig struct {

// Internal: parsed provider configs (not exposed to TOML package)
parsedClusterProviderConfigs map[string]ProviderConfig

// Internal: the config.toml directory, to help resolve relative file paths
configDirPath string
}

type GroupVersionKind struct {
Expand All @@ -76,23 +81,48 @@ type GroupVersionKind struct {
Kind string `toml:"kind,omitempty"`
}

// Read reads the toml file and returns the StaticConfig.
func Read(configPath string) (*StaticConfig, error) {
type ReadConfigOpt func(cfg *StaticConfig)

func withDirPath(path string) ReadConfigOpt {
return func(cfg *StaticConfig) {
cfg.configDirPath = path
}
}

// Read reads the toml file and returns the StaticConfig, with any opts applied.
func Read(configPath string, opts ...ReadConfigOpt) (*StaticConfig, error) {
configData, err := os.ReadFile(configPath)
if err != nil {
return nil, err
}
return ReadToml(configData)

// get and save the absolute dir path to the config file, so that other config parsers can use it
absPath, err := filepath.Abs(configPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve absolute path to config file: %w", err)
}
dirPath := filepath.Dir(absPath)

cfg, err := ReadToml(configData, append(opts, withDirPath(dirPath))...)
if err != nil {
return nil, err
}

return cfg, nil
}

// ReadToml reads the toml data and returns the StaticConfig.
func ReadToml(configData []byte) (*StaticConfig, error) {
// ReadToml reads the toml data and returns the StaticConfig, with any opts applied
func ReadToml(configData []byte, opts ...ReadConfigOpt) (*StaticConfig, error) {
config := Default()
md, err := toml.NewDecoder(bytes.NewReader(configData)).Decode(config)
if err != nil {
return nil, err
}

for _, opt := range opts {
opt(config)
}

if err := config.parseClusterProviderConfigs(md); err != nil {
return nil, err
}
Expand All @@ -111,13 +141,15 @@ func (c *StaticConfig) parseClusterProviderConfigs(md toml.MetaData) error {
c.parsedClusterProviderConfigs = make(map[string]ProviderConfig, len(c.ClusterProviderConfigs))
}

ctx := withConfigDirPath(context.Background(), c.configDirPath)

for strategy, primitive := range c.ClusterProviderConfigs {
parser, ok := getProviderConfigParser(strategy)
if !ok {
continue
}

providerConfig, err := parser(primitive, md)
providerConfig, err := parser(ctx, primitive, md)
if err != nil {
return fmt.Errorf("failed to parse config for ClusterProvider '%s': %w", strategy, err)
}
Expand Down
23 changes: 22 additions & 1 deletion pkg/config/provider_config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"context"
"fmt"

"github.com/BurntSushi/toml"
Expand All @@ -12,7 +13,27 @@ type ProviderConfig interface {
Validate() error
}

type ProviderConfigParser func(primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error)
type ProviderConfigParser func(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error)

type configDirPathKey struct{}

func withConfigDirPath(ctx context.Context, dirPath string) context.Context {
return context.WithValue(ctx, configDirPathKey{}, dirPath)
}

func ConfigDirPathFromContext(ctx context.Context) string {
val := ctx.Value(configDirPathKey{})

if val == nil {
return ""
}

if strVal, ok := val.(string); ok {
return strVal
}

return ""
}

var (
providerConfigParsers = make(map[string]ProviderConfigParser)
Expand Down
35 changes: 33 additions & 2 deletions pkg/config/provider_config_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package config

import (
"context"
"errors"
"path/filepath"
"testing"

"github.com/BurntSushi/toml"
Expand Down Expand Up @@ -42,7 +44,7 @@ func (p *ProviderConfigForTest) Validate() error {
return nil
}

func providerConfigForTestParser(primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) {
func providerConfigForTestParser(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) {
var providerConfigForTest ProviderConfigForTest
if err := md.PrimitiveDecode(primitive, &providerConfigForTest); err != nil {
return nil, err
Expand Down Expand Up @@ -131,7 +133,7 @@ func (s *ProviderConfigSuite) TestReadConfigUnregisteredProviderConfig() {
}

func (s *ProviderConfigSuite) TestReadConfigParserError() {
RegisterProviderConfig("test", func(primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) {
RegisterProviderConfig("test", func(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) {
return nil, errors.New("parser error forced by test")
})
invalidConfigPath := s.writeConfig(`
Expand All @@ -152,6 +154,35 @@ func (s *ProviderConfigSuite) TestReadConfigParserError() {
})
}

func (s *ProviderConfigSuite) TestConfigDirPathInContext() {
var capturedDirPath string
RegisterProviderConfig("test", func(ctx context.Context, primitive toml.Primitive, md toml.MetaData) (ProviderConfig, error) {
capturedDirPath = ConfigDirPathFromContext(ctx)
var providerConfigForTest ProviderConfigForTest
if err := md.PrimitiveDecode(primitive, &providerConfigForTest); err != nil {
return nil, err
}
return &providerConfigForTest, nil
})
configPath := s.writeConfig(`
cluster_provider_strategy = "test"
[cluster_provider_configs.test]
bool_prop = true
str_prop = "a string"
int_prop = 42
`)

absConfigPath, err := filepath.Abs(configPath)
s.Require().NoError(err, "test error: getting the absConfigPath should not fail")

_, err = Read(configPath)
s.Run("provides config directory path in context to parser", func() {
s.Require().NoError(err, "Expected no error reading config")
s.NotEmpty(capturedDirPath, "Expected non-empty directory path in context")
s.Equal(filepath.Dir(absConfigPath), capturedDirPath, "Expected directory path to match config file directory")
})
}

func TestProviderConfig(t *testing.T) {
suite.Run(t, new(ProviderConfigSuite))
}
Loading