Skip to content

Commit fd8d6b1

Browse files
szogoon4gust
andauthored
feat: add WithPrompt option for AuthCodeURL (#585)
* feat: add WithPrompt option for AuthCodeURL * chore: clean print statements, and use enum for prompt * chore: add WithPrompt to interactive flow * chore: revert to original error message in fakeBrowserOpenURL Co-authored-by: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> * chore: set select_account as default prompt for AcquireTokenInteractive --------- Co-authored-by: Nilesh Choudhary <107404295+4gust@users.noreply.github.com>
1 parent 30f0f89 commit fd8d6b1

File tree

5 files changed

+216
-9
lines changed

5 files changed

+216
-9
lines changed

apps/confidential/confidential.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client
359359

360360
// authCodeURLOptions contains options for AuthCodeURL
361361
type authCodeURLOptions struct {
362-
claims, loginHint, tenantID, domainHint string
362+
claims, loginHint, tenantID, domainHint, prompt string
363363
}
364364

365365
// AuthCodeURLOption is implemented by options for AuthCodeURL
@@ -369,7 +369,7 @@ type AuthCodeURLOption interface {
369369

370370
// AuthCodeURL creates a URL used to acquire an authorization code. Users need to call CreateAuthorizationCodeURLParameters and pass it in.
371371
//
372-
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
372+
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt]
373373
func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
374374
o := authCodeURLOptions{}
375375
if err := options.ApplyOptions(&o, opts); err != nil {
@@ -382,6 +382,7 @@ func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string,
382382
ap.Claims = o.claims
383383
ap.LoginHint = o.loginHint
384384
ap.DomainHint = o.domainHint
385+
ap.Prompt = o.prompt
385386
return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap)
386387
}
387388

@@ -431,6 +432,29 @@ func WithDomainHint(domain string) interface {
431432
}
432433
}
433434

435+
// WithPrompt adds prompt query parameter in the auth url.
436+
func WithPrompt(prompt shared.Prompt) interface {
437+
AuthCodeURLOption
438+
options.CallOption
439+
} {
440+
return struct {
441+
AuthCodeURLOption
442+
options.CallOption
443+
}{
444+
CallOption: options.NewCallOption(
445+
func(a any) error {
446+
switch t := a.(type) {
447+
case *authCodeURLOptions:
448+
t.prompt = prompt.String()
449+
default:
450+
return fmt.Errorf("unexpected options type %T", a)
451+
}
452+
return nil
453+
},
454+
),
455+
}
456+
}
457+
434458
// WithClaims sets additional claims to request for the token, such as those required by conditional access policies.
435459
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
436460
// This option is valid for any token acquisition method.

apps/confidential/confidential_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2248,6 +2248,58 @@ func TestWithDomainHint(t *testing.T) {
22482248
}
22492249
}
22502250

2251+
func TestWithPrompt(t *testing.T) {
2252+
prompt := shared.PromptLogin
2253+
cred, err := NewCredFromSecret(fakeSecret)
2254+
if err != nil {
2255+
t.Fatal(err)
2256+
}
2257+
client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
2258+
if err != nil {
2259+
t.Fatal(err)
2260+
}
2261+
if err != nil {
2262+
t.Fatal(err)
2263+
}
2264+
client.base.Token.AccessTokens = &fake.AccessTokens{}
2265+
client.base.Token.Authority = &fake.Authority{}
2266+
client.base.Token.Resolver = &fake.ResolveEndpoints{}
2267+
for _, expectPrompt := range []bool{true, false} {
2268+
t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) {
2269+
validate := func(v url.Values) error {
2270+
if !v.Has("prompt") {
2271+
if !expectPrompt {
2272+
return nil
2273+
}
2274+
return errors.New("expected a prompt")
2275+
} else if !expectPrompt {
2276+
return fmt.Errorf("expected no prompt, got %v", v["prompt"][0])
2277+
}
2278+
2279+
if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() {
2280+
err = fmt.Errorf(`unexpected prompt "%v"`, actual[0])
2281+
}
2282+
return err
2283+
}
2284+
var urlOpts []AuthCodeURLOption
2285+
if expectPrompt {
2286+
urlOpts = append(urlOpts, WithPrompt(prompt))
2287+
}
2288+
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
2289+
if err == nil {
2290+
var parsed *url.URL
2291+
parsed, err = url.Parse(u)
2292+
if err == nil {
2293+
err = validate(parsed.Query())
2294+
}
2295+
}
2296+
if err != nil {
2297+
t.Fatal(err)
2298+
}
2299+
})
2300+
}
2301+
}
2302+
22512303
func TestWithAuthenticationScheme(t *testing.T) {
22522304
ctx := context.Background()
22532305
authScheme := mock.NewTestAuthnScheme()

apps/internal/shared/shared.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,29 @@ func (acc Account) IsZero() bool {
7070

7171
// DefaultClient is our default shared HTTP client.
7272
var DefaultClient = &http.Client{}
73+
74+
type Prompt int64
75+
76+
const (
77+
PromptNone Prompt = iota
78+
PromptLogin
79+
PromptSelectAccount
80+
PromptConsent
81+
PromptCreate
82+
)
83+
84+
func (p Prompt) String() string {
85+
switch p {
86+
case PromptNone:
87+
return "none"
88+
case PromptLogin:
89+
return "login"
90+
case PromptSelectAccount:
91+
return "select_account"
92+
case PromptConsent:
93+
return "consent"
94+
case PromptCreate:
95+
return "create"
96+
}
97+
return ""
98+
}

apps/public/public.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func New(clientID string, options ...Option) (Client, error) {
149149

150150
// authCodeURLOptions contains options for AuthCodeURL
151151
type authCodeURLOptions struct {
152-
claims, loginHint, tenantID, domainHint string
152+
claims, loginHint, tenantID, domainHint, prompt string
153153
}
154154

155155
// AuthCodeURLOption is implemented by options for AuthCodeURL
@@ -159,7 +159,7 @@ type AuthCodeURLOption interface {
159159

160160
// AuthCodeURL creates a URL used to acquire an authorization code.
161161
//
162-
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
162+
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt]
163163
func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
164164
o := authCodeURLOptions{}
165165
if err := options.ApplyOptions(&o, opts); err != nil {
@@ -172,6 +172,7 @@ func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string,
172172
ap.Claims = o.claims
173173
ap.LoginHint = o.loginHint
174174
ap.DomainHint = o.domainHint
175+
ap.Prompt = o.prompt
175176
return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap)
176177
}
177178

@@ -526,9 +527,9 @@ func (pca Client) RemoveAccount(ctx context.Context, account Account) error {
526527

527528
// interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow.
528529
type interactiveAuthOptions struct {
529-
claims, domainHint, loginHint, redirectURI, tenantID string
530-
openURL func(url string) error
531-
authnScheme AuthenticationScheme
530+
claims, domainHint, loginHint, redirectURI, tenantID, prompt string
531+
openURL func(url string) error
532+
authnScheme AuthenticationScheme
532533
}
533534

534535
// AcquireInteractiveOption is implemented by options for AcquireTokenInteractive
@@ -590,6 +591,33 @@ func WithDomainHint(domain string) interface {
590591
}
591592
}
592593

594+
// WithPrompt adds the IdP prompt query parameter in the auth url.
595+
func WithPrompt(prompt shared.Prompt) interface {
596+
AcquireInteractiveOption
597+
AuthCodeURLOption
598+
options.CallOption
599+
} {
600+
return struct {
601+
AcquireInteractiveOption
602+
AuthCodeURLOption
603+
options.CallOption
604+
}{
605+
CallOption: options.NewCallOption(
606+
func(a any) error {
607+
switch t := a.(type) {
608+
case *authCodeURLOptions:
609+
t.prompt = prompt.String()
610+
case *interactiveAuthOptions:
611+
t.prompt = prompt.String()
612+
default:
613+
return fmt.Errorf("unexpected options type %T", a)
614+
}
615+
return nil
616+
},
617+
),
618+
}
619+
}
620+
593621
// WithRedirectURI sets a port for the local server used in interactive authentication, for
594622
// example http://localhost:port. All URI components other than the port are ignored.
595623
func WithRedirectURI(redirectURI string) interface {
@@ -674,7 +702,11 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string,
674702
authParams.LoginHint = o.loginHint
675703
authParams.DomainHint = o.domainHint
676704
authParams.State = uuid.New().String()
677-
authParams.Prompt = "select_account"
705+
if o.prompt != "" {
706+
authParams.Prompt = o.prompt
707+
} else {
708+
authParams.Prompt = shared.PromptSelectAccount.String()
709+
}
678710
if o.authnScheme != nil {
679711
authParams.AuthnScheme = o.authnScheme
680712
}

apps/public/public_test.go

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
2424
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
2525
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
26+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
2627
"github.com/kylelemons/godebug/pretty"
2728
)
2829

@@ -53,7 +54,7 @@ func fakeBrowserOpenURL(authURL string) error {
5354
}
5455
redirect := q.Get("redirect_uri")
5556
if redirect == "" {
56-
return errors.New("missing query param 'redirect_uri'")
57+
return errors.New("missing redirect param 'redirect_uri'")
5758
}
5859
// now send the info to our local redirect server
5960
resp, err := http.DefaultClient.Get(redirect + fmt.Sprintf("/?state=%s&code=fake_auth_code", state))
@@ -936,6 +937,78 @@ func TestWithDomainHint(t *testing.T) {
936937
}
937938
}
938939

940+
func TestWithPrompt(t *testing.T) {
941+
prompt := shared.PromptCreate
942+
client, err := New("client-id")
943+
if err != nil {
944+
t.Fatal(err)
945+
}
946+
client.base.Token.AccessTokens = &fake.AccessTokens{}
947+
client.base.Token.Authority = &fake.Authority{}
948+
client.base.Token.Resolver = &fake.ResolveEndpoints{}
949+
for _, expectPrompt := range []bool{true, false} {
950+
t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) {
951+
called := false
952+
validate := func(v url.Values) error {
953+
if !v.Has("prompt") {
954+
if !expectPrompt {
955+
return nil
956+
}
957+
return errors.New("expected a prompt")
958+
} else if !expectPrompt {
959+
return fmt.Errorf("expected no prompt, got %v", v["prompt"][0])
960+
}
961+
962+
if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() {
963+
err = fmt.Errorf(`unexpected prompt "%v"`, actual[0])
964+
}
965+
return err
966+
}
967+
browserOpenURL := func(authURL string) error {
968+
called = true
969+
parsed, err := url.Parse(authURL)
970+
if err != nil {
971+
return err
972+
}
973+
query, err := url.ParseQuery(parsed.RawQuery)
974+
if err != nil {
975+
return err
976+
}
977+
if err = validate(query); err != nil {
978+
t.Fatal(err)
979+
return err
980+
}
981+
// this helper validates the other params and completes the redirect
982+
return fakeBrowserOpenURL(authURL)
983+
}
984+
acquireOpts := []AcquireInteractiveOption{WithOpenURL(browserOpenURL)}
985+
var urlOpts []AuthCodeURLOption
986+
if expectPrompt {
987+
acquireOpts = append(acquireOpts, WithPrompt(prompt))
988+
urlOpts = append(urlOpts, WithPrompt(prompt))
989+
}
990+
_, err = client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...)
991+
if err != nil {
992+
t.Fatal(err)
993+
}
994+
if !called {
995+
t.Fatal("browserOpenURL wasn't called")
996+
}
997+
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
998+
if err == nil {
999+
var parsed *url.URL
1000+
parsed, err = url.Parse(u)
1001+
if err == nil {
1002+
err = validate(parsed.Query())
1003+
}
1004+
}
1005+
if err != nil {
1006+
t.Fatal(err)
1007+
}
1008+
})
1009+
}
1010+
}
1011+
9391012
func TestWithAuthenticationScheme(t *testing.T) {
9401013
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
9411014
lmo, tenant := "login.microsoftonline.com", "tenant"

0 commit comments

Comments
 (0)