diff --git a/src/cmd/cli/command/commands.go b/src/cmd/cli/command/commands.go index 4a4451a45..9eb0c223e 100644 --- a/src/cmd/cli/command/commands.go +++ b/src/cmd/cli/command/commands.go @@ -94,7 +94,6 @@ func Execute(ctx context.Context) error { if strings.Contains(err.Error(), "config") { printDefangHint("To manage sensitive service config, use:", "config") } - if strings.Contains(err.Error(), "maximum number of projects") { projectName := "" provider, err := newProvider(ctx, nil) @@ -207,7 +206,8 @@ func SetupCommands(ctx context.Context, version string) { RootCmd.AddCommand(tokenCmd) // Login Command - loginCmd.Flags().Bool("training-opt-out", false, "Opt out of ML training (Pro users only)") + loginCmd.Flags().Bool("training-opt-out", false, "opt out of ML training (Pro users only)") + loginCmd.Flags().String("token", "", "access token to use for authentication") // loginCmd.Flags().Bool("skip-prompt", false, "skip the login prompt if already logged in"); TODO: Implement this RootCmd.AddCommand(loginCmd) @@ -396,9 +396,10 @@ var loginCmd = &cobra.Command{ Short: "Authenticate to Defang", RunE: func(cmd *cobra.Command, args []string) error { trainingOptOut, _ := cmd.Flags().GetBool("training-opt-out") + token, _ := cmd.Flags().GetString("token") if nonInteractive { - if err := login.NonInteractiveGitHubLogin(cmd.Context(), client, getCluster()); err != nil { + if err := login.NonInteractiveLogin(cmd.Context(), client, getCluster(), token); err != nil { return err } } else { diff --git a/src/pkg/auth/auth.go b/src/pkg/auth/auth.go index 6f951b950..f5a8498b2 100644 --- a/src/pkg/auth/auth.go +++ b/src/pkg/auth/auth.go @@ -236,3 +236,13 @@ func ExchangeCodeForToken(ctx context.Context, code AuthCodeFlow, tenant types.T } return token.AccessToken, nil } + +func ExchangeJWTForToken(ctx context.Context, jwt string) (string, error) { + term.Debugf("Generating token for jwt %q", jwt) + + token, err := openAuthClient.ExchangeJWT(jwt) // TODO: scopes, TTL + if err != nil { + return "", err + } + return token.AccessToken, nil +} diff --git a/src/pkg/auth/client.go b/src/pkg/auth/client.go index 5efa511d1..b2a08bf5f 100644 --- a/src/pkg/auth/client.go +++ b/src/pkg/auth/client.go @@ -25,6 +25,7 @@ const ( var ( ErrInvalidAccessToken = errors.New("invalid access token") ErrInvalidAuthorizationCode = errors.New("invalid authorization code") + ErrInvalidJWT = errors.New("invalid JWT") ErrInvalidRefreshToken = errors.New("invalid refresh token") ) @@ -124,6 +125,10 @@ type Client interface { * Exchange the code for access and refresh tokens. */ Exchange(code string, redirectURI string, verifier string) (*ExchangeSuccess, error) + /** + * Exchange jwt for access and refresh tokens. + */ + ExchangeJWT(jwt string) (*ExchangeSuccess, error) /** * Refreshes the tokens if they have expired. This is used in an SPA app to maintain the * session, without logging the user out. @@ -206,6 +211,20 @@ func (c client) callToken(body url.Values) (*Tokens, error) { return &tokens.Tokens, nil } +/** + * Helper function to exchange tokens with common error handling. + */ +func (c client) exchangeForTokens(body url.Values) (*ExchangeSuccess, error) { + tokens, err := c.callToken(body) + if err != nil { + return nil, err + } + + return &ExchangeSuccess{ + Tokens: *tokens, + }, nil +} + /** * Exchange the code for access and refresh tokens. */ @@ -217,18 +236,40 @@ func (c client) Exchange(code string, redirectURI string, verifier string) (*Exc "grant_type": {"authorization_code"}, "redirect_uri": {redirectURI}, } - tokens, err := c.callToken(body) + + result, err := c.exchangeForTokens(body) if err != nil { var oauthError *OAuthError if errors.As(err, &oauthError) { return nil, fmt.Errorf("%w: %w", ErrInvalidAuthorizationCode, err) } + return nil, fmt.Errorf("token exchange failed: %w", err) } - return &ExchangeSuccess{ - Tokens: *tokens, - }, nil + return result, nil +} + +/** + * Exchange the JWT for access and refresh tokens. + */ +func (c client) ExchangeJWT(jwt string) (*ExchangeSuccess, error) { + body := url.Values{ + "grant_type": {"urn:ietf:params:oauth:grant-type:jwt-bearer"}, + "assertion": {jwt}, + } + result, err := c.exchangeForTokens(body) + + if err != nil { + var oauthError *OAuthError + if errors.As(err, &oauthError) { + return nil, fmt.Errorf("%w: %w", ErrInvalidJWT, err) + } + + return nil, fmt.Errorf("token exchange failed: %w", err) + } + + return result, nil } /** diff --git a/src/pkg/auth/client_test.go b/src/pkg/auth/client_test.go index 3d5f2e392..888660b16 100644 --- a/src/pkg/auth/client_test.go +++ b/src/pkg/auth/client_test.go @@ -63,6 +63,54 @@ func TestExchange(t *testing.T) { }) } +func TestExchangeJWT(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + if r.Method != http.MethodPost { + t.Errorf("Expected POST method, got %s", r.Method) + } + if expected, got := "urn:ietf:params:oauth:grant-type:jwt-bearer", r.PostFormValue("grant_type"); expected != got { + t.Errorf("Expected grant_type %s, got: %s", expected, got) + } + + jwt := r.PostFormValue("assertion") + if jwt == "valid-jwt" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"jwt-access-token","refresh_token":"jwt-refresh-token"}`)) + } else { + w.Write([]byte(`{"error":"invalid_request","error_description":"Invalid request"}`)) + } + default: + http.Error(w, "Not Found", http.StatusNotFound) + } + })) + t.Cleanup(server.CloseClientConnections) + + client := NewClient("defang-cli", server.URL) + + t.Run("success", func(t *testing.T) { + result, err := client.ExchangeJWT("valid-jwt") + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if result.AccessToken != "jwt-access-token" { + t.Errorf("Expected access token 'jwt-access-token', got: %s", result.AccessToken) + } + if result.RefreshToken != "jwt-refresh-token" { + t.Errorf("Expected refresh token 'jwt-refresh-token', got: %s", result.RefreshToken) + } + }) + + t.Run("invalid jwt", func(t *testing.T) { + _, err := client.ExchangeJWT("invalid-jwt") + const expected = "invalid JWT: Invalid request" + if err.Error() != expected { + t.Fatalf("Expected error %q, got: %v", expected, err) + } + }) +} + func TestAuthorizeExchange(t *testing.T) { if testing.Short() { t.Skip("skipping browser test in short mode.") diff --git a/src/pkg/login/login.go b/src/pkg/login/login.go index 39ba0cc5b..afc04faac 100644 --- a/src/pkg/login/login.go +++ b/src/pkg/login/login.go @@ -14,7 +14,6 @@ import ( "github.com/DefangLabs/defang/src/pkg/github" "github.com/DefangLabs/defang/src/pkg/term" "github.com/DefangLabs/defang/src/pkg/track" - defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/bufbuild/connect-go" ) @@ -94,21 +93,22 @@ func interactiveLogin(ctx context.Context, client client.FabricClient, fabric st return nil } -func NonInteractiveGitHubLogin(ctx context.Context, client client.FabricClient, fabric string) error { - term.Debug("Non-interactive login using GitHub Actions id-token") - idToken, err := github.GetIdToken(ctx) - if err != nil { - return fmt.Errorf("non-interactive login failed: %w", err) +func NonInteractiveLogin(ctx context.Context, client client.FabricClient, fabric string, token string) error { + if token == "" { + term.Debug("Non-interactive login using GitHub Actions id-token") + var err error + token, err = github.GetIdToken(ctx) + if err != nil { + return fmt.Errorf("non-interactive login failed: %w", err) + } + term.Debug("Got GitHub Actions id-token") } - term.Debug("Got GitHub Actions id-token") - resp, err := client.Token(ctx, &defangv1.TokenRequest{ - Assertion: idToken, - Scope: []string{"admin", "read", "delete", "tail"}, - }) + + accessToken, err := auth.ExchangeJWTForToken(ctx, token) if err != nil { - return err + return fmt.Errorf("non-interactive login failed: %w", err) } - return cluster.SaveAccessToken(fabric, resp.AccessToken) + return cluster.SaveAccessToken(fabric, accessToken) } func InteractiveRequireLoginAndToS(ctx context.Context, fabric client.FabricClient, addr string) error { diff --git a/src/pkg/login/login_test.go b/src/pkg/login/login_test.go index 092db3919..6490f8d09 100644 --- a/src/pkg/login/login_test.go +++ b/src/pkg/login/login_test.go @@ -108,6 +108,7 @@ func TestNonInteractiveLogin(t *testing.T) { ctx := context.Background() mockClient := &MockForNonInteractiveLogin{} fabric := "test.defang.dev" + token := "" t.Run("Expect accessToken to be stored when NonInteractiveLogin() succeeds", func(t *testing.T) { requestUrl := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL") @@ -122,7 +123,7 @@ func TestNonInteractiveLogin(t *testing.T) { t.Cleanup(func() { client.StateDir = prevStateDir }) - err := NonInteractiveGitHubLogin(ctx, mockClient, fabric) + err := NonInteractiveLogin(ctx, mockClient, fabric, token) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -139,7 +140,7 @@ func TestNonInteractiveLogin(t *testing.T) { t.Run("Expect error when NonInteractiveLogin() fails in the case that GitHub Actions info is not set", func(t *testing.T) { - err := NonInteractiveGitHubLogin(ctx, mockClient, fabric) + err := NonInteractiveLogin(ctx, mockClient, fabric, token) if err != nil && err.Error() != "non-interactive login failed: ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not set" { t.Fatalf("expected no error, got %v", err)