Skip to content

Commit 635806f

Browse files
authored
CLOUDP-329791: Refactor auth check (#4034)
1 parent 8017aa3 commit 635806f

File tree

4 files changed

+31
-80
lines changed

4 files changed

+31
-80
lines changed

internal/cli/root/builder.go

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,11 @@ import (
6767
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/homebrew"
6868
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/latestrelease"
6969
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/log"
70-
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/plugin"
7170
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/prerun"
7271
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/sighandle"
7372
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/telemetry"
7473
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/terminal"
7574
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/usage"
76-
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/validate"
7775
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/version"
7876
"github.com/spf13/afero"
7977
"github.com/spf13/cobra"
@@ -88,18 +86,6 @@ type Notifier struct {
8886
writer io.Writer
8987
}
9088

91-
type AuthRequirements int64
92-
93-
const (
94-
// NoAuth command does not require authentication.
95-
NoAuth AuthRequirements = 0
96-
// RequiredAuth command requires authentication.
97-
RequiredAuth AuthRequirements = 1
98-
// OptionalAuth command can work with or without authentication,
99-
// and if access token is found, try to refresh it.
100-
OptionalAuth AuthRequirements = 2
101-
)
102-
10389
func handleSignal() {
10490
sighandle.Notify(func(sig os.Signal) {
10591
telemetry.FinishTrackingCommand(telemetry.TrackOptions{
@@ -149,29 +135,7 @@ Use the --help flag with any command for more info on that command.`,
149135
config.SetService(config.CloudService)
150136
}
151137

152-
authReq := shouldCheckCredentials(cmd)
153-
if authReq == NoAuth {
154-
return nil
155-
}
156-
157-
if err := prerun.ExecuteE(
158-
opts.InitFlow(config.Default()),
159-
func() error {
160-
err := opts.RefreshAccessToken(cmd.Context())
161-
if err != nil && authReq == RequiredAuth {
162-
return err
163-
}
164-
return nil
165-
},
166-
); err != nil {
167-
return err
168-
}
169-
170-
if authReq == RequiredAuth {
171-
return validate.Credentials()
172-
}
173-
174-
return nil
138+
return prerun.ExecuteE(opts.InitFlow(config.Default()))
175139
},
176140
// PersistentPostRun only runs if the command is successful
177141
PersistentPostRun: func(cmd *cobra.Command, _ []string) {
@@ -289,43 +253,6 @@ func shouldSetService(cmd *cobra.Command) bool {
289253
return true
290254
}
291255

292-
func shouldCheckCredentials(cmd *cobra.Command) AuthRequirements {
293-
searchByName := []string{
294-
"__complete",
295-
"help",
296-
}
297-
for _, n := range searchByName {
298-
if cmd.Name() == n {
299-
return NoAuth
300-
}
301-
}
302-
customRequirements := map[string]AuthRequirements{
303-
fmt.Sprintf("%s %s", atlas, "completion"): NoAuth, // completion commands do not require credentials
304-
fmt.Sprintf("%s %s", atlas, "config"): NoAuth, // user wants to set credentials
305-
fmt.Sprintf("%s %s", atlas, "auth"): NoAuth, // user wants to set credentials
306-
fmt.Sprintf("%s %s", atlas, "register"): NoAuth, // user wants to set credentials
307-
fmt.Sprintf("%s %s", atlas, "login"): NoAuth, // user wants to set credentials
308-
fmt.Sprintf("%s %s", atlas, "logout"): NoAuth, // user wants to set credentials
309-
fmt.Sprintf("%s %s", atlas, "whoami"): NoAuth, // user wants to set credentials
310-
fmt.Sprintf("%s %s", atlas, "setup"): NoAuth, // user wants to set credentials
311-
fmt.Sprintf("%s %s", atlas, "register"): NoAuth, // user wants to set credentials
312-
fmt.Sprintf("%s %s", atlas, "plugin"): NoAuth, // plugin functionality requires no authentication
313-
fmt.Sprintf("%s %s", atlas, "quickstart"): NoAuth, // command supports login
314-
fmt.Sprintf("%s %s", atlas, "deployments"): OptionalAuth, // command supports local and Atlas
315-
}
316-
for p, r := range customRequirements {
317-
if strings.HasPrefix(cmd.CommandPath(), p) {
318-
return r
319-
}
320-
}
321-
322-
if plugin.IsPluginCmd(cmd) || pluginCmd.IsFirstClassPluginCmd(cmd) {
323-
return OptionalAuth
324-
}
325-
326-
return RequiredAuth
327-
}
328-
329256
func formattedVersion() string {
330257
return fmt.Sprintf(verTemplate,
331258
version.Version,

internal/store/store.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ type Store struct {
5555
}
5656

5757
func (s *Store) httpClient(httpTransport http.RoundTripper) (*http.Client, error) {
58-
if s.username != "" && s.password != "" {
58+
switch {
59+
case s.username != "" && s.password != "":
5960
t := transport.NewDigestTransport(s.username, s.password, httpTransport)
6061
return t.Client()
61-
}
62-
if s.accessToken != nil {
62+
case s.accessToken != nil:
6363
tr, err := transport.NewAccessTokenTransport(s.accessToken, httpTransport, func(t *atlasauth.Token) error {
6464
config.SetAccessToken(t.AccessToken)
6565
config.SetRefreshToken(t.RefreshToken)
@@ -69,10 +69,11 @@ func (s *Store) httpClient(httpTransport http.RoundTripper) (*http.Client, error
6969
return nil, err
7070
}
7171

72+
return &http.Client{Transport: tr}, nil
73+
default:
74+
tr := &transport.AuthRequiredRoundTripper{Base: httpTransport}
7275
return &http.Client{Transport: tr}, nil
7376
}
74-
75-
return &http.Client{Transport: httpTransport}, nil
7677
}
7778

7879
func (s *Store) transport() *http.Transport {

internal/transport/transport.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
package transport
1616

1717
import (
18+
"fmt"
1819
"net"
1920
"net/http"
2021
"time"
2122

2223
"github.com/mongodb-forks/digest"
2324
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config"
2425
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/oauth"
26+
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/validate"
2527
atlasauth "go.mongodb.org/atlas/auth"
2628
)
2729

@@ -110,3 +112,24 @@ func (tr *tokenTransport) RoundTrip(req *http.Request) (*http.Response, error) {
110112

111113
return tr.base.RoundTrip(req)
112114
}
115+
116+
type AuthRequiredRoundTripper struct {
117+
Base http.RoundTripper
118+
}
119+
120+
func (tr *AuthRequiredRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
121+
resp, err := tr.Base.RoundTrip(req)
122+
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
123+
return nil, fmt.Errorf(
124+
`%w
125+
126+
To log in using your Atlas username and password, run: atlas auth login
127+
To set credentials using API keys, run: atlas config init`,
128+
validate.ErrMissingCredentials,
129+
)
130+
}
131+
if err != nil {
132+
return nil, err
133+
}
134+
return resp, nil
135+
}

test/e2e/brew/brew_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828

2929
const (
3030
profileString = "PROFILE NAME"
31-
errorMessage = "Error: this action requires authentication"
31+
errorMessage = "this action requires authentication"
3232
)
3333

3434
func TestAtlasCLIConfig(t *testing.T) {

0 commit comments

Comments
 (0)