Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/cmd/cli/command/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func Execute(ctx context.Context) error {
if strings.Contains(err.Error(), "config") {
printDefangHint("To manage sensitive service config, use:", "config")
}

if strings.Contains(err.Error(), "maximum number of projects") {
projectName := "<name>"
provider, err := newProvider(ctx, nil)
Expand Down Expand Up @@ -207,7 +206,8 @@ func SetupCommands(ctx context.Context, version string) {
RootCmd.AddCommand(tokenCmd)

// Login Command
loginCmd.Flags().Bool("training-opt-out", false, "Opt out of ML training (Pro users only)")
loginCmd.Flags().Bool("training-opt-out", false, "opt out of ML training (Pro users only)")
loginCmd.Flags().String("token", "", "access token to use for authentication")
// loginCmd.Flags().Bool("skip-prompt", false, "skip the login prompt if already logged in"); TODO: Implement this
RootCmd.AddCommand(loginCmd)

Expand Down Expand Up @@ -396,9 +396,10 @@ var loginCmd = &cobra.Command{
Short: "Authenticate to Defang",
RunE: func(cmd *cobra.Command, args []string) error {
trainingOptOut, _ := cmd.Flags().GetBool("training-opt-out")
token, _ := cmd.Flags().GetString("token")

if nonInteractive {
if err := login.NonInteractiveGitHubLogin(cmd.Context(), client, getCluster()); err != nil {
if err := login.NonInteractiveLogin(cmd.Context(), client, getCluster(), token); err != nil {
return err
}
} else {
Expand Down
10 changes: 10 additions & 0 deletions src/pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,13 @@ func ExchangeCodeForToken(ctx context.Context, code AuthCodeFlow, tenant types.T
}
return token.AccessToken, nil
}

func ExchangeJWTForToken(ctx context.Context, jwt string) (string, error) {
term.Debugf("Generating token for jwt %q", jwt)

token, err := openAuthClient.ExchangeJWT(jwt) // TODO: scopes, TTL
if err != nil {
return "", err
}
return token.AccessToken, nil
}
49 changes: 45 additions & 4 deletions src/pkg/auth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
var (
ErrInvalidAccessToken = errors.New("invalid access token")
ErrInvalidAuthorizationCode = errors.New("invalid authorization code")
ErrInvalidJWT = errors.New("invalid JWT")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
)

Expand Down Expand Up @@ -124,6 +125,10 @@ type Client interface {
* Exchange the code for access and refresh tokens.
*/
Exchange(code string, redirectURI string, verifier string) (*ExchangeSuccess, error)
/**
* Exchange jwt for access and refresh tokens.
*/
ExchangeJWT(jwt string) (*ExchangeSuccess, error)
/**
* Refreshes the tokens if they have expired. This is used in an SPA app to maintain the
* session, without logging the user out.
Expand Down Expand Up @@ -206,6 +211,20 @@ func (c client) callToken(body url.Values) (*Tokens, error) {
return &tokens.Tokens, nil
}

/**
* Helper function to exchange tokens with common error handling.
*/
func (c client) exchangeForTokens(body url.Values) (*ExchangeSuccess, error) {
tokens, err := c.callToken(body)
if err != nil {
return nil, err
}

return &ExchangeSuccess{
Tokens: *tokens,
}, nil
}

/**
* Exchange the code for access and refresh tokens.
*/
Expand All @@ -217,18 +236,40 @@ func (c client) Exchange(code string, redirectURI string, verifier string) (*Exc
"grant_type": {"authorization_code"},
"redirect_uri": {redirectURI},
}
tokens, err := c.callToken(body)

result, err := c.exchangeForTokens(body)
if err != nil {
var oauthError *OAuthError
if errors.As(err, &oauthError) {
return nil, fmt.Errorf("%w: %w", ErrInvalidAuthorizationCode, err)
}

return nil, fmt.Errorf("token exchange failed: %w", err)
}

return &ExchangeSuccess{
Tokens: *tokens,
}, nil
return result, nil
}

/**
* Exchange the JWT for access and refresh tokens.
*/
func (c client) ExchangeJWT(jwt string) (*ExchangeSuccess, error) {
body := url.Values{
"grant_type": {"urn:ietf:params:oauth:grant-type:jwt-bearer"},
"assertion": {jwt},
}
result, err := c.exchangeForTokens(body)

if err != nil {
var oauthError *OAuthError
if errors.As(err, &oauthError) {
return nil, fmt.Errorf("%w: %w", ErrInvalidJWT, err)
}

return nil, fmt.Errorf("token exchange failed: %w", err)
}

return result, nil
}

/**
Expand Down
48 changes: 48 additions & 0 deletions src/pkg/auth/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,54 @@ func TestExchange(t *testing.T) {
})
}

func TestExchangeJWT(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
if r.Method != http.MethodPost {
t.Errorf("Expected POST method, got %s", r.Method)
}
if expected, got := "urn:ietf:params:oauth:grant-type:jwt-bearer", r.PostFormValue("grant_type"); expected != got {
t.Errorf("Expected grant_type %s, got: %s", expected, got)
}

jwt := r.PostFormValue("assertion")
if jwt == "valid-jwt" {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":"jwt-access-token","refresh_token":"jwt-refresh-token"}`))
} else {
w.Write([]byte(`{"error":"invalid_request","error_description":"Invalid request"}`))
}
default:
http.Error(w, "Not Found", http.StatusNotFound)
}
}))
t.Cleanup(server.CloseClientConnections)

client := NewClient("defang-cli", server.URL)

t.Run("success", func(t *testing.T) {
result, err := client.ExchangeJWT("valid-jwt")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if result.AccessToken != "jwt-access-token" {
t.Errorf("Expected access token 'jwt-access-token', got: %s", result.AccessToken)
}
if result.RefreshToken != "jwt-refresh-token" {
t.Errorf("Expected refresh token 'jwt-refresh-token', got: %s", result.RefreshToken)
}
})

t.Run("invalid jwt", func(t *testing.T) {
_, err := client.ExchangeJWT("invalid-jwt")
const expected = "invalid JWT: Invalid request"
if err.Error() != expected {
t.Fatalf("Expected error %q, got: %v", expected, err)
}
})
}

func TestAuthorizeExchange(t *testing.T) {
if testing.Short() {
t.Skip("skipping browser test in short mode.")
Expand Down
26 changes: 13 additions & 13 deletions src/pkg/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/DefangLabs/defang/src/pkg/github"
"github.com/DefangLabs/defang/src/pkg/term"
"github.com/DefangLabs/defang/src/pkg/track"
defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1"
"github.com/bufbuild/connect-go"
)

Expand Down Expand Up @@ -94,21 +93,22 @@ func interactiveLogin(ctx context.Context, client client.FabricClient, fabric st
return nil
}

func NonInteractiveGitHubLogin(ctx context.Context, client client.FabricClient, fabric string) error {
term.Debug("Non-interactive login using GitHub Actions id-token")
idToken, err := github.GetIdToken(ctx)
if err != nil {
return fmt.Errorf("non-interactive login failed: %w", err)
func NonInteractiveLogin(ctx context.Context, client client.FabricClient, fabric string, token string) error {
if token == "" {
term.Debug("Non-interactive login using GitHub Actions id-token")
var err error
token, err = github.GetIdToken(ctx)
if err != nil {
return fmt.Errorf("non-interactive login failed: %w", err)
}
term.Debug("Got GitHub Actions id-token")
}
term.Debug("Got GitHub Actions id-token")
resp, err := client.Token(ctx, &defangv1.TokenRequest{
Assertion: idToken,
Scope: []string{"admin", "read", "delete", "tail"},
})

accessToken, err := auth.ExchangeJWTForToken(ctx, token)
if err != nil {
return err
return fmt.Errorf("non-interactive login failed: %w", err)
}
return cluster.SaveAccessToken(fabric, resp.AccessToken)
return cluster.SaveAccessToken(fabric, accessToken)
}

func InteractiveRequireLoginAndToS(ctx context.Context, fabric client.FabricClient, addr string) error {
Expand Down
5 changes: 3 additions & 2 deletions src/pkg/login/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func TestNonInteractiveLogin(t *testing.T) {
ctx := context.Background()
mockClient := &MockForNonInteractiveLogin{}
fabric := "test.defang.dev"
token := ""

t.Run("Expect accessToken to be stored when NonInteractiveLogin() succeeds", func(t *testing.T) {
requestUrl := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
Expand All @@ -122,7 +123,7 @@ func TestNonInteractiveLogin(t *testing.T) {

t.Cleanup(func() { client.StateDir = prevStateDir })

err := NonInteractiveGitHubLogin(ctx, mockClient, fabric)
err := NonInteractiveLogin(ctx, mockClient, fabric, token)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
Expand All @@ -139,7 +140,7 @@ func TestNonInteractiveLogin(t *testing.T) {

t.Run("Expect error when NonInteractiveLogin() fails in the case that GitHub Actions info is not set",
func(t *testing.T) {
err := NonInteractiveGitHubLogin(ctx, mockClient, fabric)
err := NonInteractiveLogin(ctx, mockClient, fabric, token)
if err != nil &&
err.Error() != "non-interactive login failed: ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not set" {
t.Fatalf("expected no error, got %v", err)
Expand Down
Loading