|
1 | 1 | package common |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
| 5 | + "encoding/json" |
4 | 6 | "fmt" |
5 | 7 | "io/ioutil" |
6 | 8 | "os" |
@@ -83,34 +85,68 @@ func LoadFileIntoVar(filePath string, destContent *string) error { |
83 | 85 | return nil |
84 | 86 | } |
85 | 87 |
|
86 | | -// WriteConfigFile writes new content of the config file. |
87 | | -// If the file does not exist, it is created at default location |
88 | | -// TODO temporary solution until upstream https://github.com/spf13/viper/issues/433 is fixed |
89 | | -func WriteConfigFile() error { |
90 | | - cf := viper.ConfigFileUsed() |
| 88 | +// return the file handle of a config file |
| 89 | +// if it does not exist yet, creates a new one at default location |
| 90 | +func getCurrentOrNewConfigFile() (string, error) { |
91 | 91 |
|
| 92 | + cf := viper.ConfigFileUsed() |
92 | 93 | if cf == "" { |
93 | 94 | fullname := ConfigFileName + "." + ConfigFileType |
94 | 95 | if dirname, err := os.UserHomeDir(); err == nil { |
95 | 96 | cf = filepath.Join(dirname, ConfigHomeSubdir, ConfigFuseMLSubdir, fullname) |
96 | 97 | } |
97 | 98 | if cf == "" { |
98 | | - return errors.New("Failed to acquire config directory name") |
| 99 | + return "", errors.New("Failed to acquire config directory name") |
99 | 100 | } |
100 | 101 | configDirPath := filepath.Dir(cf) |
101 | 102 | if err := os.MkdirAll(configDirPath, os.ModePerm); err != nil { |
102 | | - return err |
| 103 | + return "", err |
103 | 104 | } |
104 | 105 |
|
105 | 106 | fmt.Printf("FuseML configuration file created at %s\n", cf) |
106 | 107 | } |
| 108 | + return cf, nil |
| 109 | +} |
| 110 | + |
| 111 | +// WriteConfigFile writes new content of the config file. |
| 112 | +// If the file does not exist, it is created at default location |
| 113 | +// TODO temporary solution until upstream https://github.com/spf13/viper/issues/433 is fixed |
| 114 | +func WriteConfigFile() error { |
| 115 | + cf, err := getCurrentOrNewConfigFile() |
| 116 | + if err != nil { |
| 117 | + return err |
| 118 | + } |
107 | 119 |
|
108 | 120 | if err := viper.WriteConfigAs(cf); err != nil { |
109 | 121 | return err |
110 | 122 | } |
111 | 123 | return nil |
112 | 124 | } |
113 | 125 |
|
| 126 | +// Writes the current content of file and also deletes given key from it |
| 127 | +// As viper does not support this operation directly, here's a workaround |
| 128 | +// taken from https://github.com/spf13/viper/issues/632 |
| 129 | +func DeleteKeyAndWriteConfigFile(key string) error { |
| 130 | + cf, err := getCurrentOrNewConfigFile() |
| 131 | + if err != nil { |
| 132 | + return err |
| 133 | + } |
| 134 | + |
| 135 | + configMap := viper.AllSettings() |
| 136 | + |
| 137 | + delete(configMap, strings.ToLower(key)) |
| 138 | + encodedConfig, _ := json.MarshalIndent(configMap, "", " ") |
| 139 | + |
| 140 | + err = viper.ReadConfig(bytes.NewReader(encodedConfig)) |
| 141 | + if err != nil { |
| 142 | + return err |
| 143 | + } |
| 144 | + if err := viper.WriteConfigAs(cf); err != nil { |
| 145 | + return err |
| 146 | + } |
| 147 | + return nil |
| 148 | +} |
| 149 | + |
114 | 150 | // ValidateEnumArgument is used to validate command line arguments that can take a limited set of values |
115 | 151 | func ValidateEnumArgument(argName, argValue string, values []string) error { |
116 | 152 | if !util.StringInSlice(argValue, values) { |
|
0 commit comments