Skip to content

Commit 761bdb5

Browse files
authored
Merge pull request #117 from xia0pin9/feat/improve-register-command
feat(client): improve register command error handling
2 parents 1039c19 + d818d42 commit 761bdb5

File tree

3 files changed

+383
-39
lines changed

3 files changed

+383
-39
lines changed

client/register.go

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,17 @@ Register will cache the key in the file system and keep it up to date using the
2323
-k specifies a specific key identifier to register
2424
-f specifies a file containing a new line separated list of key identifiers
2525
-t specifies a timeout for getting the key from the daemon (e.g. '5s', '500ms')
26-
-g gets the key as well
26+
-g gets the key as well (in no-cache mode, only -g with -k is supported)
2727
2828
For a machine to access a certain key, it needs permissions on that key.
2929
3030
Note that knox register will only update the register file and will return successful
3131
even if the machine does not have access to the key. The daemon will actually retrieve
3232
the key.
3333
34+
In no-cache mode (when using UncachedHTTPClient), only the -g flag with -k is supported
35+
to fetch a key directly from the server without local registration.
36+
3437
For more about knox, see https://github.com/pinterest/knox.
3538
3639
See also: knox unregister, knox daemon
@@ -58,51 +61,61 @@ func parseTimeout(val string) (time.Duration, error) {
5861
}
5962

6063
func runRegister(cmd *Command, args []string) *ErrorStatus {
61-
if _, ok := cli.(*knox.UncachedHTTPClient); ok {
62-
fmt.Println("Cannot Register in No Cache mode")
63-
return nil
64+
// Validate -g flag usage early: -g requires -k (cannot be used with -f alone)
65+
if *registerAndGet && *registerKey == "" {
66+
return &ErrorStatus{fmt.Errorf("the -g flag requires -k to specify a single key to retrieve"), false}
6467
}
65-
timeout, err := parseTimeout(*registerTimeout)
66-
if err != nil {
67-
return &ErrorStatus{fmt.Errorf("Invalid value for timeout flag: %s", err.Error()), false}
68+
69+
_, isUncachedMode := cli.(*knox.UncachedHTTPClient)
70+
71+
// In uncached mode, only support -g with -k to fetch a key directly from the server
72+
if isUncachedMode {
73+
if !*registerAndGet {
74+
return &ErrorStatus{fmt.Errorf("cannot register keys in no-cache mode; use -g with -k to fetch a key directly"), false}
75+
}
76+
// -k is already validated above
77+
// Skip registration, go directly to fetching the key
78+
return fetchAndPrintKey(*registerKey, *registerTimeout)
6879
}
6980

7081
k := NewKeysFile(path.Join(daemonFolder, daemonToRegister))
82+
// Handle `knox register -r` (without -k or -f) to remove all registered keys
7183
if *registerRemove && *registerKey == "" && *registerKeyFile == "" {
72-
// Short circuit & handle `knox register -r`, which is expected to remove all keys
7384
err := k.Lock()
7485
if err != nil {
75-
return &ErrorStatus{fmt.Errorf("There was an error obtaining file lock: %s", err.Error()), false}
86+
return &ErrorStatus{fmt.Errorf("error obtaining file lock: %w", err), false}
7687
}
7788
err = k.Overwrite([]string{})
7889
if err != nil {
7990
k.Unlock()
80-
return &ErrorStatus{fmt.Errorf("Failed to unregister all keys: %s", err.Error()), false}
91+
return &ErrorStatus{fmt.Errorf("failed to unregister all keys: %w", err), false}
8192
}
8293
err = k.Unlock()
8394
if err != nil {
84-
return &ErrorStatus{fmt.Errorf("There was an error unlocking register file: %s", err.Error()), false}
95+
return &ErrorStatus{fmt.Errorf("error unlocking register file: %w", err), false}
8596
}
8697
logf("Successfully unregistered all keys.")
8798
return nil
8899
} else if *registerKey == "" && *registerKeyFile == "" {
89-
return &ErrorStatus{fmt.Errorf("You must include a key or key file to register. see 'knox help register'"), false}
100+
return &ErrorStatus{fmt.Errorf("you must include a key or key file to register; see 'knox help register'"), false}
90101
}
91102
// Get the list of keys to add
92103
var ks []string
104+
var err error
93105
if *registerKey == "" {
94106
f := NewKeysFile(*registerKeyFile)
95107
ks, err = f.Get()
96108
if err != nil {
97-
return &ErrorStatus{fmt.Errorf("There was an error reading input key file %s", err.Error()), false}
109+
return &ErrorStatus{fmt.Errorf("error reading input key file: %w", err), false}
98110
}
99111
} else {
100112
ks = []string{*registerKey}
101113
}
102114
// Handle adding new keys to the registered file
115+
// When -r is specified with -k or -f, this replaces all registered keys with the specified ones
103116
err = k.Lock()
104117
if err != nil {
105-
return &ErrorStatus{fmt.Errorf("There was an error obtaining file lock: %s", err.Error()), false}
118+
return &ErrorStatus{fmt.Errorf("error obtaining file lock: %w", err), false}
106119
}
107120
if *registerRemove {
108121
logf("Attempting to overwrite existing keys with %v.", ks)
@@ -112,34 +125,74 @@ func runRegister(cmd *Command, args []string) *ErrorStatus {
112125
}
113126
if err != nil {
114127
k.Unlock()
115-
return &ErrorStatus{fmt.Errorf("There was an error registering keys %v: %s", ks, err.Error()), false}
128+
return &ErrorStatus{fmt.Errorf("error registering keys %v: %w", ks, err), false}
116129
}
117130
err = k.Unlock()
118131
if err != nil {
119-
return &ErrorStatus{fmt.Errorf("There was an error unlocking register file: %s", err.Error()), false}
132+
return &ErrorStatus{fmt.Errorf("error unlocking register file: %w", err), false}
120133
}
121-
// If specified, force retrieval of keys
134+
// If specified, force retrieval of keys (already validated that -k is set when -g is used)
122135
if *registerAndGet {
123-
key, err := cli.CacheGetKey(*registerKey)
124-
c := time.After(timeout)
125-
for err != nil {
126-
select {
127-
case <-c:
136+
return fetchAndPrintKey(*registerKey, *registerTimeout)
137+
}
138+
logf("Successfully registered keys %v. Keys are updated by the daemon process every %.0f minutes. Check the log for the most recent run.", ks, daemonRefreshTime.Minutes())
139+
return nil
140+
}
141+
142+
// fetchAndPrintKey fetches a key from the server and prints it as JSON.
143+
// This is used by both cached and uncached modes when -g flag is specified.
144+
//
145+
// Note: The timeout bounds the total retry time, but individual CacheGetKey calls
146+
// may block beyond the deadline since CacheGetKey doesn't support context cancellation.
147+
// The deadline is checked before each retry attempt to minimize overage.
148+
func fetchAndPrintKey(keyID string, timeoutStr string) *ErrorStatus {
149+
timeout, err := parseTimeout(timeoutStr)
150+
if err != nil {
151+
return &ErrorStatus{fmt.Errorf("invalid value for timeout flag: %w", err), false}
152+
}
153+
154+
// Start deadline timer before first call to bound total time
155+
deadline := time.After(timeout)
156+
var key *knox.Key
157+
var fetchErr error // Track fetch errors separately for clarity
158+
159+
for {
160+
// Check timeout before each attempt
161+
select {
162+
case <-deadline:
163+
if fetchErr != nil {
128164
return &ErrorStatus{fmt.Errorf(
129-
"Error getting key from daemon (hit timeout after %s seconds); check knox logs for details (most recent error: %v)",
130-
timeout.String(), err), false}
131-
case <-time.After(registerRecheckTime):
132-
key, err = cli.CacheGetKey(*registerKey)
165+
"error getting key from server (hit timeout after %s): %w",
166+
timeout.String(), fetchErr), false}
133167
}
168+
// Timeout on first attempt before any fetch was made
169+
return &ErrorStatus{fmt.Errorf(
170+
"error getting key from server (hit timeout after %s before fetch attempt)",
171+
timeout.String()), false}
172+
default:
173+
// Continue with fetch attempt
134174
}
135-
// TODO: add json vs data option?
136-
data, err := json.Marshal(key)
137-
if err != nil {
138-
return &ErrorStatus{err, true}
175+
176+
key, fetchErr = cli.CacheGetKey(keyID)
177+
if fetchErr == nil {
178+
break
179+
}
180+
181+
// Wait before retry, but also check deadline
182+
select {
183+
case <-deadline:
184+
return &ErrorStatus{fmt.Errorf(
185+
"error getting key from server (hit timeout after %s): %w",
186+
timeout.String(), fetchErr), false}
187+
case <-time.After(registerRecheckTime):
188+
// Continue to next attempt
139189
}
140-
fmt.Printf("%s", string(data))
141-
return nil
142190
}
143-
logf("Successfully registered keys %v. Keys are updated by the daemon process every %.0f minutes. Check the log for the most recent run.", ks, daemonRefreshTime.Minutes())
191+
192+
data, marshalErr := json.Marshal(key)
193+
if marshalErr != nil {
194+
return &ErrorStatus{marshalErr, true}
195+
}
196+
fmt.Printf("%s", string(data))
144197
return nil
145198
}

0 commit comments

Comments
 (0)