Skip to content

Commit c80bcb5

Browse files
authored
Refactor config into two-phase processing with shared flags (#9)
* refactor: split config processing into PreProcess and Process phases * make lint and make test
1 parent d5cd862 commit c80bcb5

File tree

18 files changed

+829
-497
lines changed

18 files changed

+829
-497
lines changed

internal/api/dashboard/config.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ package dashboard
33
import (
44
"net/http"
55

6+
"github.com/infracost/cli/internal/config/process"
67
"github.com/infracost/cli/pkg/environment"
78
)
89

910
var (
10-
defaultValues = map[environment.Environment]map[string]string{
11+
_ process.Processor = (*Config)(nil)
12+
13+
defaultValues = map[string]map[string]string{
1114
environment.Production: {
1215
"endpoint": "https://dashboard.api.infracost.io",
1316
},
@@ -21,7 +24,14 @@ var (
2124
)
2225

2326
type Config struct {
24-
Endpoint string `env:"INFRACOST_CLI_DASHBOARD_ENDPOINT" flag:"dashboard-endpoint;hidden" usage:"The endpoint for the Infracost dashboard"`
27+
Environment string `flagvalue:"environment"`
28+
Endpoint string `env:"INFRACOST_CLI_DASHBOARD_ENDPOINT" flag:"dashboard-endpoint;hidden" usage:"The endpoint for the Infracost dashboard"`
29+
}
30+
31+
func (c *Config) Process() {
32+
if c.Endpoint == "" {
33+
c.Endpoint = defaultValues[c.Environment]["endpoint"]
34+
}
2535
}
2636

2737
func (c *Config) Client(client *http.Client) *Client {
@@ -30,9 +40,3 @@ func (c *Config) Client(client *http.Client) *Client {
3040
config: c,
3141
}
3242
}
33-
34-
func (c *Config) ApplyDefaults(env environment.Environment) {
35-
if c.Endpoint == "" {
36-
c.Endpoint = defaultValues[env]["endpoint"]
37-
}
38-
}

internal/cache/cache_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func testOutput() format.Output {
5151
func testConfig(t *testing.T) *Config {
5252
t.Helper()
5353
c := &Config{Cache: filepath.Join(t.TempDir(), "cache")}
54-
c.ApplyDefaults()
54+
c.Process()
5555
return c
5656
}
5757

internal/cache/config.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@ import (
77
"path/filepath"
88
"time"
99

10+
"github.com/infracost/cli/internal/api/events"
11+
config "github.com/infracost/cli/internal/config/process"
1012
"github.com/infracost/cli/internal/logging"
1113
"github.com/shirou/gopsutil/process"
1214
)
1315

16+
var (
17+
_ config.Processor = (*Config)(nil)
18+
)
19+
1420
type Config struct {
1521
// Cache is where the cache files should go.
1622
Cache string `env:"INFRACOST_CLI_CACHE_DIRECTORY"`
@@ -23,7 +29,7 @@ type Config struct {
2329
SessionID string `env:"INFRACOST_SESSION_ID"`
2430
}
2531

26-
func (c *Config) ApplyDefaults() {
32+
func (c *Config) Process() {
2733
if len(c.Cache) == 0 {
2834
c.Cache = defaultCachePath()
2935
}
@@ -33,6 +39,7 @@ func (c *Config) ApplyDefaults() {
3339
if c.SessionID == "" {
3440
c.SessionID = getSessionID()
3541
}
42+
events.RegisterMetadata("session", c.SessionID)
3643
}
3744

3845
func defaultCachePath() string {

internal/config/config.go

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ import (
66
"github.com/infracost/cli/internal/api/dashboard"
77
"github.com/infracost/cli/internal/api/events"
88
"github.com/infracost/cli/internal/cache"
9+
"github.com/infracost/cli/internal/config/process"
910
"github.com/infracost/cli/internal/logging"
1011
"github.com/infracost/cli/pkg/auth"
1112
"github.com/infracost/cli/pkg/environment"
1213
"github.com/infracost/cli/pkg/plugins"
13-
"github.com/spf13/cobra"
14-
"github.com/spf13/pflag"
14+
)
15+
16+
var (
17+
_ process.Processor = (*Config)(nil)
1518
)
1619

1720
// Config contains the configuration for the CLI.
@@ -51,18 +54,9 @@ type Config struct {
5154
Cache cache.Config
5255
}
5356

54-
func (config *Config) RegisterEventMetadata(cmd *cobra.Command) {
55-
events.RegisterMetadata("command", cmd.Name())
56-
events.RegisterMetadata("flags", func() []string {
57-
var flags []string
58-
cmd.Flags().Visit(func(flag *pflag.Flag) {
59-
flags = append(flags, flag.Name)
60-
})
61-
return flags
62-
}())
63-
events.RegisterMetadata("session", config.Cache.SessionID)
57+
func (config *Config) Process() {
6458
events.RegisterMetadata("cloudEnabled", os.Getenv("INFRACOST_ENABLE_CLOUD") == "true")
6559
events.RegisterMetadata("dashboardEnabled", os.Getenv("INFRACOST_ENABLE_DASHBOARD") == "true")
66-
events.RegisterMetadata("environment", string(config.Environment))
60+
events.RegisterMetadata("environment", config.Environment.String())
6761
events.RegisterMetadata("isDefaultPricingApiEndpoint", config.PricingEndpoint == "https://pricing.api.infracost.io")
6862
}

internal/config/config_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package config
2+
3+
import (
4+
"testing"
5+
6+
"github.com/infracost/cli/internal/config/process"
7+
"github.com/spf13/pflag"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestConfig_Process(t *testing.T) {
12+
var cfg Config
13+
14+
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
15+
16+
// first, make sure that preprocess doesn't error or panic when no values provided.
17+
if diags := process.PreProcess(&cfg, flags); diags.Len() != 0 {
18+
t.Fatal(diags)
19+
}
20+
require.NoError(t, flags.Parse(nil)) // we have no required flags yet, so will provide nothing
21+
process.Process(&cfg) // make sure doesn't panic
22+
23+
// environment is a shared flag, so let's make sure that all worked
24+
require.Equal(t, "prod", cfg.Environment.String())
25+
require.Equal(t, "prod", cfg.Auth.Environment)
26+
require.Equal(t, "prod", cfg.Dashboard.Environment)
27+
}
Lines changed: 46 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package config
1+
package process
22

33
import (
44
"fmt"
@@ -17,50 +17,32 @@ var (
1717
valueType = reflect.TypeOf((*pflag.Value)(nil)).Elem()
1818
)
1919

20-
// Process populates a struct's fields from environment variables and command-line flags.
21-
//
22-
// The target must be a pointer to a struct. Each field can be tagged with `env`, `flag`, `usage`,
23-
// and `default` to specify how the field should be populated.
24-
//
25-
// If no tag is specified, the field name is used as the basis for both the environment variable
26-
// (converted to SCREAMING_SNAKE_CASE) and the flag name (converted to kebab-case). This behavior
27-
// can be disabled by setting the `env` or `flag` tag to `-`.
28-
//
29-
// Example:
30-
//
31-
// type MyConfig struct {
32-
// MyField string `env:"MY_FIELD" flag:"my-field;hidden" usage:"my field usage" default:"my-default"`
33-
// }
34-
//
35-
// Flags can be marked as hidden by appending `;hidden` to the `flag` tag.
36-
//
37-
// Process supports nested structs, and basic types such as strings, integers, and booleans.
38-
// Additionally, any type implementing the pflag.Value interface is supported.
39-
func Process(target interface{}, flags *pflag.FlagSet) *diagnostic.Diagnostics {
20+
// PreProcess walks the target struct and hydrates fields based on struct tags.
21+
// Fields tagged with `env:"VAR"` are populated from environment variables,
22+
// `default:"val"` sets a fallback when no env var is found and the field is zero,
23+
// and `flag:"name"` registers a pflag on the provided FlagSet. Nested structs
24+
// are recursed into unless they implement pflag.Value, in which case they are
25+
// treated as leaf values. The target must be a pointer to a struct.
26+
func PreProcess(target interface{}, flags *pflag.FlagSet) *diagnostic.Diagnostics {
4027
v := reflect.ValueOf(target)
4128

4229
if v.Kind() == reflect.Interface {
43-
// we'll support interfaces, but we'll only allow one level of indirection
44-
// basically, just allows people to pass in things that have been assigned to a interface{} or any type
45-
// constraint
4630
v = v.Elem()
4731
}
4832

4933
if v.Kind() != reflect.Pointer {
50-
// but, we must have a pointer to a struct at the next level
5134
panic("target must be a pointer to a struct")
5235
}
53-
v = v.Elem() // unpack the pointer
36+
v = v.Elem()
5437

5538
if v.Kind() != reflect.Struct {
56-
// we must now actually have the struct we're going to be working on
5739
panic("target must be a pointer to a struct")
5840
}
5941

60-
return processStruct(v, flags)
42+
return preprocess(v, flags)
6143
}
6244

63-
func processStruct(v reflect.Value, flags *pflag.FlagSet) *diagnostic.Diagnostics {
45+
func preprocess(v reflect.Value, flags *pflag.FlagSet) *diagnostic.Diagnostics {
6446
var diags *diagnostic.Diagnostics
6547

6648
t := v.Type()
@@ -74,31 +56,37 @@ func processStruct(v reflect.Value, flags *pflag.FlagSet) *diagnostic.Diagnostic
7456

7557
envName, hasEnvName := field.Tag.Lookup("env")
7658
flagValue, hasFlagValue := field.Tag.Lookup("flag")
59+
flagTargetName, hasFlagTarget := field.Tag.Lookup("flagvalue")
7760
defaultValue, hasDefaultValue := field.Tag.Lookup("default")
7861

7962
hasEnvName = hasEnvName && envName != ""
8063
hasFlagValue = hasFlagValue && flagValue != ""
64+
hasFlagTarget = hasFlagTarget && flagTargetName != ""
8165
hasDefaultValue = hasDefaultValue && defaultValue != ""
8266

8367
currentType, parentType := unpackType(fieldValue.Type(), fieldValue.Addr().Type())
8468
isPflagValue := parentType.Implements(valueType)
8569

8670
if currentType.Kind() == reflect.Struct && !isPflagValue {
87-
if hasEnvName || hasFlagValue || hasDefaultValue {
71+
if hasEnvName || hasFlagValue || hasDefaultValue || hasFlagTarget {
8872
// programmer error, so we panic
89-
panic("nested structs cannot be tagged with env, flag or default, or they must implement pflag.Value")
73+
panic("nested structs cannot be tagged with env, flag, flagvalue or default, or they must implement pflag.Value")
9074
}
9175

9276
current, _ := unpackValue(fieldValue)
9377

9478
// Then we have a struct that needs to be processed recursively.
95-
if err := processStruct(current, flags); err != nil {
79+
if err := preprocess(current, flags); err != nil {
9680
return err
9781
}
9882
}
9983

100-
if !hasEnvName && !hasFlagValue && !hasDefaultValue {
101-
// if we have no env, flag or default value then we're not going to touch this field
84+
if hasFlagTarget && (hasFlagValue || hasEnvName || hasDefaultValue) {
85+
panic("flagvalue cannot be combined with flag, env, or default tags")
86+
}
87+
88+
if !hasEnvName && !hasFlagValue && !hasFlagTarget && !hasDefaultValue {
89+
// if we have no env, flag, flagvalue or default value then we're not going to touch this field
10290
continue
10391
}
10492

@@ -142,11 +130,35 @@ func processStruct(v reflect.Value, flags *pflag.FlagSet) *diagnostic.Diagnostic
142130
registerFlag(parent, flags, flagName, field.Tag.Get("usage"), hidden, isPflagValue)
143131
}
144132

133+
if hasFlagTarget {
134+
existing := flags.Lookup(flagTargetName)
135+
if existing == nil {
136+
panic(fmt.Sprintf("flagvalue %q references flag that has not been registered", flagTargetName))
137+
}
138+
139+
sf, ok := existing.Value.(SharedFlag)
140+
if !ok {
141+
panic(fmt.Sprintf("flagvalue %q references flag that does not implement SharedFlag", flagTargetName))
142+
}
143+
144+
if fieldValue.Kind() != reflect.String {
145+
panic(fmt.Sprintf("flagvalue %q can only be used on string fields", flagTargetName))
146+
}
147+
148+
target := fieldValue.Addr().Interface().(*string)
149+
if existing.DefValue != "" {
150+
*target = existing.DefValue
151+
}
152+
sf.AddTarget(target)
153+
}
154+
145155
}
146156

147157
return diags
148158
}
149159

160+
// setFieldValue sets a reflected field's value from a string. It handles
161+
// pflag.Value types, time.Duration, and primitive kinds (string, int, bool).
150162
func setFieldValue(v reflect.Value, s string, isPflagValue bool) error {
151163
if isPflagValue {
152164
pv := v.Addr().Interface().(pflag.Value)
@@ -243,30 +255,3 @@ func registerFlag(v reflect.Value, flags *pflag.FlagSet, name string, usage stri
243255
}
244256
}
245257
}
246-
247-
// unpackType will unwrap the type, iterating through pointers and interfaces until the real type has been discovered.
248-
// It returns the source type and the parent type to it.
249-
func unpackType(t reflect.Type, parent reflect.Type) (reflect.Type, reflect.Type) {
250-
for t.Kind() == reflect.Pointer {
251-
parent = t
252-
t = t.Elem() // then unpack all the pointers to get the real core type
253-
}
254-
return t, parent
255-
}
256-
257-
// unpackValue will unwrap the provided value, iterating through pointers until the real value has been discovered.
258-
// It returns the source value and the parent value to it.
259-
//
260-
// We'll initialize pointers as we go, but not the inner most pointer. Callers must check the returned parent value
261-
// for nil, before they try and set any values on the returned value.
262-
func unpackValue(value reflect.Value) (reflect.Value, reflect.Value) {
263-
parent := value.Addr()
264-
for value.Kind() == reflect.Pointer {
265-
parent = value
266-
if value.IsNil() {
267-
value.Set(reflect.New(value.Type().Elem()))
268-
}
269-
value = value.Elem()
270-
}
271-
return value, parent
272-
}

0 commit comments

Comments
 (0)