Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions apps/public/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func New(clientID string, options ...Option) (Client, error) {

// authCodeURLOptions contains options for AuthCodeURL
type authCodeURLOptions struct {
claims, loginHint, tenantID, domainHint string
claims, loginHint, tenantID, domainHint, state string
}

// AuthCodeURLOption is implemented by options for AuthCodeURL
Expand All @@ -152,7 +152,7 @@ type AuthCodeURLOption interface {

// AuthCodeURL creates a URL used to acquire an authorization code.
//
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithState]
func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
o := authCodeURLOptions{}
if err := options.ApplyOptions(&o, opts); err != nil {
Expand All @@ -162,12 +162,36 @@ func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string,
if err != nil {
return "", err
}
ap.State = o.state
ap.Claims = o.claims
ap.LoginHint = o.loginHint
ap.DomainHint = o.domainHint
return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap)
}

// WithState adds a user-generated state to the request.
func WithState(state string) interface {
AuthCodeURLOption
options.CallOption
} {
return struct {
AuthCodeURLOption
options.CallOption
}{
CallOption: options.NewCallOption(
func(a any) error {
switch t := a.(type) {
case *authCodeURLOptions:
t.state = state
default:
return fmt.Errorf("unexpected options type %T", a)
}
return nil
},
),
}
}

// WithClaims sets additional claims to request for the token, such as those required by conditional access policies.
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
// This option is valid for any token acquisition method.
Expand Down
38 changes: 38 additions & 0 deletions apps/public/public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,44 @@ func TestWithDomainHint(t *testing.T) {
}
}

func TestWithState(t *testing.T) {
state := "abc-123-secure-string"
client, err := New("client-id")
if err != nil {
t.Fatal(err)
}
client.base.Token.AccessTokens = &fake.AccessTokens{}
client.base.Token.Authority = &fake.Authority{}
client.base.Token.Resolver = &fake.ResolveEndpoints{}
for _, expectHint := range []bool{true, false} {
t.Run(fmt.Sprint(expectHint), func(t *testing.T) {
var urlOpts []AuthCodeURLOption
if expectHint {
urlOpts = append(urlOpts, WithState(state))
}
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
if err != nil {
t.Fatal(err)
}
parsed, err := url.Parse(u)
if err != nil {
t.Fatal(err)
}
if !parsed.Query().Has("state") {
if !expectHint {
return
}
t.Fatal("expected a state")
} else if !expectHint {
t.Fatal("expected no state")
}
if actual := parsed.Query()["state"]; len(actual) != 1 || actual[0] != state {
t.Fatalf(`unexpected state "%v"`, actual)
}
})
}
}

func TestWithAuthenticationScheme(t *testing.T) {
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
lmo, tenant := "login.microsoftonline.com", "tenant"
Expand Down