Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client

// authCodeURLOptions contains options for AuthCodeURL
type authCodeURLOptions struct {
claims, loginHint, tenantID, domainHint string
claims, loginHint, tenantID, domainHint, prompt string
}

// AuthCodeURLOption is implemented by options for AuthCodeURL
Expand All @@ -369,7 +369,7 @@ type AuthCodeURLOption interface {

// AuthCodeURL creates a URL used to acquire an authorization code. Users need to call CreateAuthorizationCodeURLParameters and pass it in.
//
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt]
func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
o := authCodeURLOptions{}
if err := options.ApplyOptions(&o, opts); err != nil {
Expand All @@ -382,6 +382,7 @@ func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string,
ap.Claims = o.claims
ap.LoginHint = o.loginHint
ap.DomainHint = o.domainHint
ap.Prompt = o.prompt
return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap)
}

Expand Down Expand Up @@ -431,6 +432,29 @@ func WithDomainHint(domain string) interface {
}
}

// WithPrompt adds prompt query parameter in the auth url.
func WithPrompt(prompt shared.Prompt) interface {
AuthCodeURLOption
options.CallOption
} {
return struct {
AuthCodeURLOption
options.CallOption
}{
CallOption: options.NewCallOption(
func(a any) error {
switch t := a.(type) {
case *authCodeURLOptions:
t.prompt = prompt.String()
default:
return fmt.Errorf("unexpected options type %T", a)
}
return nil
},
),
}
}

// WithClaims sets additional claims to request for the token, such as those required by conditional access policies.
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
// This option is valid for any token acquisition method.
Expand Down
53 changes: 53 additions & 0 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
Expand Down Expand Up @@ -1774,6 +1775,58 @@ func TestWithDomainHint(t *testing.T) {
}
}

func TestWithPrompt(t *testing.T) {
prompt := shared.PromptLogin
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
if err != nil {
t.Fatal(err)
}
client.base.Token.AccessTokens = &fake.AccessTokens{}
client.base.Token.Authority = &fake.Authority{}
client.base.Token.Resolver = &fake.ResolveEndpoints{}
for _, expectPrompt := range []bool{true, false} {
t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) {
validate := func(v url.Values) error {
if !v.Has("prompt") {
if !expectPrompt {
return nil
}
return errors.New("expected a prompt")
} else if !expectPrompt {
return fmt.Errorf("expected no prompt, got %v", v["prompt"][0])
}

if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() {
err = fmt.Errorf(`unexpected prompt "%v"`, actual[0])
}
return err
}
var urlOpts []AuthCodeURLOption
if expectPrompt {
urlOpts = append(urlOpts, WithPrompt(prompt))
}
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
if err == nil {
var parsed *url.URL
parsed, err = url.Parse(u)
if err == nil {
err = validate(parsed.Query())
}
}
if err != nil {
t.Fatal(err)
}
})
}
}

func TestWithAuthenticationScheme(t *testing.T) {
ctx := context.Background()
authScheme := mock.NewTestAuthnScheme()
Expand Down
26 changes: 26 additions & 0 deletions apps/internal/shared/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,29 @@ func (acc Account) IsZero() bool {

// DefaultClient is our default shared HTTP client.
var DefaultClient = &http.Client{}

type Prompt int64

const (
PromptNone Prompt = iota
PromptLogin
PromptSelectAccount
PromptConsent
PromptCreate
)

func (p Prompt) String() string {
switch p {
case PromptNone:
return "none"
case PromptLogin:
return "login"
case PromptSelectAccount:
return "select_account"
case PromptConsent:
return "consent"
case PromptCreate:
return "create"
}
return ""
}
34 changes: 29 additions & 5 deletions apps/public/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func New(clientID string, options ...Option) (Client, error) {

// authCodeURLOptions contains options for AuthCodeURL
type authCodeURLOptions struct {
claims, loginHint, tenantID, domainHint string
claims, loginHint, tenantID, domainHint, prompt string
}

// AuthCodeURLOption is implemented by options for AuthCodeURL
Expand All @@ -159,7 +159,7 @@ type AuthCodeURLOption interface {

// AuthCodeURL creates a URL used to acquire an authorization code.
//
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt]
func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
o := authCodeURLOptions{}
if err := options.ApplyOptions(&o, opts); err != nil {
Expand All @@ -172,6 +172,7 @@ func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string,
ap.Claims = o.claims
ap.LoginHint = o.loginHint
ap.DomainHint = o.domainHint
ap.Prompt = o.prompt
return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap)
}

Expand Down Expand Up @@ -526,9 +527,9 @@ func (pca Client) RemoveAccount(ctx context.Context, account Account) error {

// interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow.
type interactiveAuthOptions struct {
claims, domainHint, loginHint, redirectURI, tenantID string
openURL func(url string) error
authnScheme AuthenticationScheme
claims, domainHint, loginHint, redirectURI, tenantID, prompt string
openURL func(url string) error
authnScheme AuthenticationScheme
}

// AcquireInteractiveOption is implemented by options for AcquireTokenInteractive
Expand Down Expand Up @@ -590,6 +591,29 @@ func WithDomainHint(domain string) interface {
}
}

// WithPrompt adds the IdP prompt query parameter in the auth url.
func WithPrompt(prompt shared.Prompt) interface {
AuthCodeURLOption
options.CallOption
} {
return struct {
AuthCodeURLOption
options.CallOption
}{
CallOption: options.NewCallOption(
func(a any) error {
switch t := a.(type) {
case *authCodeURLOptions:
t.prompt = prompt.String()
default:
return fmt.Errorf("unexpected options type %T", a)
}
return nil
},
),
}
}

// WithRedirectURI sets a port for the local server used in interactive authentication, for
// example http://localhost:port. All URI components other than the port are ignored.
func WithRedirectURI(redirectURI string) interface {
Expand Down
46 changes: 46 additions & 0 deletions apps/public/public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
"github.com/kylelemons/godebug/pretty"
)

Expand Down Expand Up @@ -935,6 +936,51 @@ func TestWithDomainHint(t *testing.T) {
}
}

func TestWithPrompt(t *testing.T) {
prompt := shared.PromptSelectAccount
client, err := New("client-id")
if err != nil {
t.Fatal(err)
}
client.base.Token.AccessTokens = &fake.AccessTokens{}
client.base.Token.Authority = &fake.Authority{}
client.base.Token.Resolver = &fake.ResolveEndpoints{}
for _, expectPrompt := range []bool{true, false} {
t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) {
validate := func(v url.Values) error {
if !v.Has("prompt") {
if !expectPrompt {
return nil
}
return errors.New("expected a prompt")
} else if !expectPrompt {
return fmt.Errorf("expected no prompt, got %v", v["prompt"][0])
}

if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() {
err = fmt.Errorf(`unexpected prompt "%v"`, actual[0])
}
return err
}
var urlOpts []AuthCodeURLOption
if expectPrompt {
urlOpts = append(urlOpts, WithPrompt(prompt))
}
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
if err == nil {
var parsed *url.URL
parsed, err = url.Parse(u)
if err == nil {
err = validate(parsed.Query())
}
}
if err != nil {
t.Fatal(err)
}
})
}
}

func TestWithAuthenticationScheme(t *testing.T) {
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
lmo, tenant := "login.microsoftonline.com", "tenant"
Expand Down