Skip to content

Commit f805290

Browse files
authored
Fix problems with reading config files using --config (#1230)
* Fix problems with reading config files using `--config` * Add test to verify that we can read config files outside of ~/.config
1 parent be22021 commit f805290

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

pkg/config/config.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ var (
3737
// `cmd.PersistentFlags().AddFlagSet(connection.Flags)`
3838
Flags *pflag.FlagSet = CreateFlagSet()
3939

40-
// Directory is $HOME/config/registry
40+
// Directory is $HOME/.config/registry
4141
Directory string
4242
ErrCannotDeleteActive = fmt.Errorf("cannot delete active configuration")
4343
ErrReservedConfigName = fmt.Errorf("%q is reserved", ActivePointerFilename)
@@ -100,10 +100,6 @@ func ValidateName(name string) error {
100100
if name == ActivePointerFilename {
101101
return ErrReservedConfigName
102102
}
103-
104-
if dir, _ := filepath.Split(name); dir != "" {
105-
return fmt.Errorf("%q must not include a path", name)
106-
}
107103
return nil
108104
}
109105

@@ -199,9 +195,10 @@ func ReadValid(name string) (c Configuration, err error) {
199195
}
200196

201197
// Read loads a Configuration from yaml file matching `name`. If name
202-
// contains a path, the file will be read from that path, otherwise
203-
// the path is assumed as: ~/.config/registry. Does a simple read from the
204-
// file: does not bind to env vars or flags, resolve, or validate.
198+
// contains a path or refers to a local file, the file will be read
199+
// using name, otherwise the path is assumed as: ~/.config/registry.
200+
// Does a simple read from the file: does not bind to env vars or flags,
201+
// resolve, or validate.
205202
// See also: ReadValid()
206203
func Read(name string) (c Configuration, err error) {
207204
if err = ValidateName(name); err != nil {
@@ -210,7 +207,11 @@ func Read(name string) (c Configuration, err error) {
210207

211208
dir, file := filepath.Split(name)
212209
if dir == "" {
213-
name = filepath.Join(Directory, file)
210+
// If name refers to a local file, preferentially read the local file.
211+
// Otherwise assume name refers to a file in the config directory.
212+
if info, err := os.Stat(file); errors.Is(err, os.ErrNotExist) || info.IsDir() {
213+
name = filepath.Join(Directory, file)
214+
}
214215
}
215216
var r io.Reader
216217
if r, err = os.Open(name); err != nil {

pkg/config/configuration_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package config_test
1616

1717
import (
18+
"log"
1819
"os"
1920
"path"
2021
"path/filepath"
@@ -24,6 +25,7 @@ import (
2425
"github.com/apigee/registry/pkg/config"
2526
"github.com/apigee/registry/pkg/config/test"
2627
"github.com/google/go-cmp/cmp"
28+
"gopkg.in/yaml.v2"
2729
)
2830

2931
func TestMissingDirectory(t *testing.T) {
@@ -476,3 +478,40 @@ func TestResolve(t *testing.T) {
476478
t.Errorf("want: %s, got: %s", "hello", c.Registry.Token)
477479
}
478480
}
481+
482+
func TestReadExternalFile(t *testing.T) {
483+
c := config.Configuration{
484+
Registry: config.Registry{Address: "example.com:443"},
485+
}
486+
bytes, err := yaml.Marshal(c)
487+
if err != nil {
488+
t.Fatal(err)
489+
}
490+
f, err := os.CreateTemp(".", "tmpfile-")
491+
if err != nil {
492+
log.Fatal(err)
493+
}
494+
if _, err = f.Write(bytes); err != nil {
495+
log.Fatal(err)
496+
}
497+
if err = f.Close(); err != nil {
498+
log.Fatal(err)
499+
}
500+
defer os.Remove(f.Name())
501+
// Verify that we can read a file using its full path name.
502+
c2, err := config.Read(f.Name())
503+
if err != nil {
504+
t.Fatal(err)
505+
}
506+
if c2.Registry.Address != c.Registry.Address {
507+
t.Errorf("want: %s, got: %s", c2.Registry.Address, c.Registry.Address)
508+
}
509+
// Verify that we can read a local file using its base name.
510+
c3, err := config.Read(filepath.Base(f.Name()))
511+
if err != nil {
512+
t.Fatal(err)
513+
}
514+
if c3.Registry.Address != c.Registry.Address {
515+
t.Errorf("want: %s, got: %s", c3.Registry.Address, c.Registry.Address)
516+
}
517+
}

0 commit comments

Comments
 (0)