diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 549d68ab..e90aa5c4 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -359,7 +359,7 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { - claims, loginHint, tenantID, domainHint string + claims, loginHint, tenantID, domainHint, prompt string } // AuthCodeURLOption is implemented by options for AuthCodeURL @@ -369,7 +369,7 @@ type AuthCodeURLOption interface { // AuthCodeURL creates a URL used to acquire an authorization code. Users need to call CreateAuthorizationCodeURLParameters and pass it in. // -// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] +// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt] func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { @@ -382,6 +382,7 @@ func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, ap.Claims = o.claims ap.LoginHint = o.loginHint ap.DomainHint = o.domainHint + ap.Prompt = o.prompt return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap) } @@ -431,6 +432,29 @@ func WithDomainHint(domain string) interface { } } +// WithPrompt adds prompt query parameter in the auth url. +func WithPrompt(prompt shared.Prompt) interface { + AuthCodeURLOption + options.CallOption +} { + return struct { + AuthCodeURLOption + options.CallOption + }{ + CallOption: options.NewCallOption( + func(a any) error { + switch t := a.(type) { + case *authCodeURLOptions: + t.prompt = prompt.String() + default: + return fmt.Errorf("unexpected options type %T", a) + } + return nil + }, + ), + } +} + // WithClaims sets additional claims to request for the token, such as those required by conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. // This option is valid for any token acquisition method. diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 165a662f..23e68afe 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -33,6 +33,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request @@ -1774,6 +1775,58 @@ func TestWithDomainHint(t *testing.T) { } } +func TestWithPrompt(t *testing.T) { + prompt := shared.PromptLogin + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) + if err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + client.base.Token.AccessTokens = &fake.AccessTokens{} + client.base.Token.Authority = &fake.Authority{} + client.base.Token.Resolver = &fake.ResolveEndpoints{} + for _, expectPrompt := range []bool{true, false} { + t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) { + validate := func(v url.Values) error { + if !v.Has("prompt") { + if !expectPrompt { + return nil + } + return errors.New("expected a prompt") + } else if !expectPrompt { + return fmt.Errorf("expected no prompt, got %v", v["prompt"][0]) + } + + if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() { + err = fmt.Errorf(`unexpected prompt "%v"`, actual[0]) + } + return err + } + var urlOpts []AuthCodeURLOption + if expectPrompt { + urlOpts = append(urlOpts, WithPrompt(prompt)) + } + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + if err == nil { + var parsed *url.URL + parsed, err = url.Parse(u) + if err == nil { + err = validate(parsed.Query()) + } + } + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestWithAuthenticationScheme(t *testing.T) { ctx := context.Background() authScheme := mock.NewTestAuthnScheme() diff --git a/apps/internal/shared/shared.go b/apps/internal/shared/shared.go index d8ab7135..77376d6f 100644 --- a/apps/internal/shared/shared.go +++ b/apps/internal/shared/shared.go @@ -70,3 +70,29 @@ func (acc Account) IsZero() bool { // DefaultClient is our default shared HTTP client. var DefaultClient = &http.Client{} + +type Prompt int64 + +const ( + PromptNone Prompt = iota + PromptLogin + PromptSelectAccount + PromptConsent + PromptCreate +) + +func (p Prompt) String() string { + switch p { + case PromptNone: + return "none" + case PromptLogin: + return "login" + case PromptSelectAccount: + return "select_account" + case PromptConsent: + return "consent" + case PromptCreate: + return "create" + } + return "" +} diff --git a/apps/public/public.go b/apps/public/public.go index 797c086c..97c61545 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -149,7 +149,7 @@ func New(clientID string, options ...Option) (Client, error) { // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { - claims, loginHint, tenantID, domainHint string + claims, loginHint, tenantID, domainHint, prompt string } // AuthCodeURLOption is implemented by options for AuthCodeURL @@ -159,7 +159,7 @@ type AuthCodeURLOption interface { // AuthCodeURL creates a URL used to acquire an authorization code. // -// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] +// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt] func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { @@ -172,6 +172,7 @@ func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, ap.Claims = o.claims ap.LoginHint = o.loginHint ap.DomainHint = o.domainHint + ap.Prompt = o.prompt return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap) } @@ -526,9 +527,9 @@ func (pca Client) RemoveAccount(ctx context.Context, account Account) error { // interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. type interactiveAuthOptions struct { - claims, domainHint, loginHint, redirectURI, tenantID string - openURL func(url string) error - authnScheme AuthenticationScheme + claims, domainHint, loginHint, redirectURI, tenantID, prompt string + openURL func(url string) error + authnScheme AuthenticationScheme } // AcquireInteractiveOption is implemented by options for AcquireTokenInteractive @@ -590,6 +591,33 @@ func WithDomainHint(domain string) interface { } } +// WithPrompt adds the IdP prompt query parameter in the auth url. +func WithPrompt(prompt shared.Prompt) interface { + AcquireInteractiveOption + AuthCodeURLOption + options.CallOption +} { + return struct { + AcquireInteractiveOption + AuthCodeURLOption + options.CallOption + }{ + CallOption: options.NewCallOption( + func(a any) error { + switch t := a.(type) { + case *authCodeURLOptions: + t.prompt = prompt.String() + case *interactiveAuthOptions: + t.prompt = prompt.String() + default: + return fmt.Errorf("unexpected options type %T", a) + } + return nil + }, + ), + } +} + // WithRedirectURI sets a port for the local server used in interactive authentication, for // example http://localhost:port. All URI components other than the port are ignored. func WithRedirectURI(redirectURI string) interface { @@ -674,7 +702,11 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, authParams.LoginHint = o.loginHint authParams.DomainHint = o.domainHint authParams.State = uuid.New().String() - authParams.Prompt = "select_account" + if o.prompt != "" { + authParams.Prompt = o.prompt + } else { + authParams.Prompt = shared.PromptSelectAccount.String() + } if o.authnScheme != nil { authParams.AuthnScheme = o.authnScheme } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index fa019ca5..f024a529 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -22,6 +22,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" "github.com/kylelemons/godebug/pretty" ) @@ -52,7 +53,7 @@ func fakeBrowserOpenURL(authURL string) error { } redirect := q.Get("redirect_uri") if redirect == "" { - return errors.New("missing query param 'redirect_uri'") + return errors.New("missing redirect param 'redirect_uri'") } // now send the info to our local redirect server resp, err := http.DefaultClient.Get(redirect + fmt.Sprintf("/?state=%s&code=fake_auth_code", state)) @@ -935,6 +936,78 @@ func TestWithDomainHint(t *testing.T) { } } +func TestWithPrompt(t *testing.T) { + prompt := shared.PromptCreate + client, err := New("client-id") + if err != nil { + t.Fatal(err) + } + client.base.Token.AccessTokens = &fake.AccessTokens{} + client.base.Token.Authority = &fake.Authority{} + client.base.Token.Resolver = &fake.ResolveEndpoints{} + for _, expectPrompt := range []bool{true, false} { + t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) { + called := false + validate := func(v url.Values) error { + if !v.Has("prompt") { + if !expectPrompt { + return nil + } + return errors.New("expected a prompt") + } else if !expectPrompt { + return fmt.Errorf("expected no prompt, got %v", v["prompt"][0]) + } + + if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() { + err = fmt.Errorf(`unexpected prompt "%v"`, actual[0]) + } + return err + } + browserOpenURL := func(authURL string) error { + called = true + parsed, err := url.Parse(authURL) + if err != nil { + return err + } + query, err := url.ParseQuery(parsed.RawQuery) + if err != nil { + return err + } + if err = validate(query); err != nil { + t.Fatal(err) + return err + } + // this helper validates the other params and completes the redirect + return fakeBrowserOpenURL(authURL) + } + acquireOpts := []AcquireInteractiveOption{WithOpenURL(browserOpenURL)} + var urlOpts []AuthCodeURLOption + if expectPrompt { + acquireOpts = append(acquireOpts, WithPrompt(prompt)) + urlOpts = append(urlOpts, WithPrompt(prompt)) + } + _, err = client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...) + if err != nil { + t.Fatal(err) + } + if !called { + t.Fatal("browserOpenURL wasn't called") + } + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + if err == nil { + var parsed *url.URL + parsed, err = url.Parse(u) + if err == nil { + err = validate(parsed.Query()) + } + } + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestWithAuthenticationScheme(t *testing.T) { clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo, tenant := "login.microsoftonline.com", "tenant"