Skip to content

Commit a36c563

Browse files
authored
Added support for storage managed disk challenge (Azure#20418)
1 parent 8ba0f80 commit a36c563

File tree

10 files changed

+253
-8
lines changed

10 files changed

+253
-8
lines changed

sdk/storage/azblob/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Features Added
66

77
* Added [Blob Batch API](https://learn.microsoft.com/rest/api/storageservices/blob-batch).
8+
* Added support for bearer challenge for identity based managed disks.
89

910
### Breaking Changes
1011

sdk/storage/azblob/appendblob/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type Client base.CompositeClient[generated.BlobClient, generated.AppendBlobClien
3535
// - cred - an Azure AD credential, typically obtained via the azidentity module
3636
// - options - client options; pass nil to accept the default values
3737
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
38-
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
38+
authPolicy := shared.NewStorageChallengePolicy(cred)
3939
conOptions := shared.GetClientOptions(options)
4040
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
4141
pl := runtime.NewPipeline(exported.ModuleName,

sdk/storage/azblob/blob/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type Client base.Client[generated.BlobClient]
3737
// - cred - an Azure AD credential, typically obtained via the azidentity module
3838
// - options - client options; pass nil to accept the default values
3939
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
40-
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
40+
authPolicy := shared.NewStorageChallengePolicy(cred)
4141
conOptions := shared.GetClientOptions(options)
4242
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
4343
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)

sdk/storage/azblob/blockblob/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ type Client base.CompositeClient[generated.BlobClient, generated.BlockBlobClient
4242
// - cred - an Azure AD credential, typically obtained via the azidentity module
4343
// - options - client options; pass nil to accept the default values
4444
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
45-
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
45+
authPolicy := shared.NewStorageChallengePolicy(cred)
4646
conOptions := shared.GetClientOptions(options)
4747
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
4848
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)

sdk/storage/azblob/container/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ type Client base.Client[generated.ContainerClient]
4444
// - cred - an Azure AD credential, typically obtained via the azidentity module
4545
// - options - client options; pass nil to accept the default values
4646
func NewClient(containerURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
47-
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
47+
authPolicy := shared.NewStorageChallengePolicy(cred)
4848
conOptions := shared.GetClientOptions(options)
4949
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
5050
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
@@ -351,7 +351,7 @@ func (c *Client) NewBatchBuilder() (*BatchBuilder, error) {
351351

352352
switch cred := c.credential().(type) {
353353
case *azcore.TokenCredential:
354-
authPolicy = runtime.NewBearerTokenPolicy(*cred, []string{shared.TokenScope}, nil)
354+
authPolicy = shared.NewStorageChallengePolicy(*cred)
355355
case *SharedKeyCredential:
356356
authPolicy = exported.NewSharedKeyCredPolicy(cred)
357357
case nil:
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
// Copyright (c) Microsoft Corporation. All rights reserved.
5+
// Licensed under the MIT License. See License.txt in the project root for license information.
6+
7+
package shared
8+
9+
import (
10+
"errors"
11+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
12+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
13+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
14+
"net/http"
15+
"strings"
16+
)
17+
18+
type storageAuthorizer struct {
19+
scopes []string
20+
tenantID string
21+
}
22+
23+
func NewStorageChallengePolicy(cred azcore.TokenCredential) policy.Policy {
24+
s := storageAuthorizer{scopes: []string{TokenScope}}
25+
return runtime.NewBearerTokenPolicy(cred, []string{TokenScope}, &policy.BearerTokenOptions{
26+
AuthorizationHandler: policy.AuthorizationHandler{
27+
OnRequest: s.onRequest,
28+
OnChallenge: s.onChallenge,
29+
},
30+
})
31+
}
32+
33+
func (s *storageAuthorizer) onRequest(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error {
34+
if len(s.scopes) == 0 || s.tenantID == "" {
35+
// returning nil indicates the bearer token policy should send the request
36+
return nil
37+
}
38+
return authNZ(policy.TokenRequestOptions{Scopes: s.scopes})
39+
}
40+
41+
func (s *storageAuthorizer) onChallenge(req *policy.Request, resp *http.Response, authNZ func(policy.TokenRequestOptions) error) error {
42+
// parse the challenge
43+
err := s.parseChallenge(resp)
44+
if err != nil {
45+
return err
46+
}
47+
// TODO: Set tenantID when policy.TokenRequestOptions supports it. https://github.com/Azure/azure-sdk-for-go/issues/19841
48+
return authNZ(policy.TokenRequestOptions{Scopes: s.scopes})
49+
}
50+
51+
type challengePolicyError struct {
52+
err error
53+
}
54+
55+
func (c *challengePolicyError) Error() string {
56+
return c.err.Error()
57+
}
58+
59+
func (*challengePolicyError) NonRetriable() {
60+
// marker method
61+
}
62+
63+
func (c *challengePolicyError) Unwrap() error {
64+
return c.err
65+
}
66+
67+
// parses Tenant ID from auth challenge
68+
// https://login.microsoftonline.com/00000000-0000-0000-0000-000000000000/oauth2/authorize
69+
func parseTenant(url string) string {
70+
if url == "" {
71+
return ""
72+
}
73+
parts := strings.Split(url, "/")
74+
if len(parts) >= 3 {
75+
tenant := parts[3]
76+
tenant = strings.ReplaceAll(tenant, ",", "")
77+
return tenant
78+
} else {
79+
return ""
80+
}
81+
}
82+
83+
func (s *storageAuthorizer) parseChallenge(resp *http.Response) error {
84+
authHeader := resp.Header.Get("WWW-Authenticate")
85+
if authHeader == "" {
86+
return &challengePolicyError{err: errors.New("response has no WWW-Authenticate header for challenge authentication")}
87+
}
88+
89+
// Strip down to auth and resource
90+
// Format is "Bearer authorization_uri=\"<site>\" resource_id=\"<site>\""
91+
authHeader = strings.ReplaceAll(authHeader, "Bearer ", "")
92+
93+
parts := strings.Split(authHeader, " ")
94+
95+
vals := map[string]string{}
96+
for _, part := range parts {
97+
subParts := strings.Split(part, "=")
98+
if len(subParts) == 2 {
99+
stripped := strings.ReplaceAll(subParts[1], "\"", "")
100+
stripped = strings.TrimSuffix(stripped, ",")
101+
vals[subParts[0]] = stripped
102+
}
103+
}
104+
105+
s.tenantID = parseTenant(vals["authorization_uri"])
106+
107+
scope := vals["resource_id"]
108+
if scope == "" {
109+
return &challengePolicyError{err: errors.New("could not find a valid resource in the WWW-Authenticate header")}
110+
}
111+
112+
if !strings.HasSuffix(scope, "/.default") {
113+
scope += "/.default"
114+
}
115+
s.scopes = []string{scope}
116+
return nil
117+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
// Copyright (c) Microsoft Corporation. All rights reserved.
5+
// Licensed under the MIT License. See License.txt in the project root for license information.
6+
7+
package shared
8+
9+
import (
10+
"context"
11+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
12+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
13+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
14+
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
15+
"github.com/stretchr/testify/require"
16+
"net/http"
17+
"strings"
18+
"testing"
19+
"time"
20+
)
21+
22+
type credentialFunc func(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error)
23+
24+
func (cf credentialFunc) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
25+
return cf(ctx, options)
26+
}
27+
28+
func TestChallengePolicy(t *testing.T) {
29+
accessToken := "***"
30+
storageResource := "https://storage.azure.com"
31+
storageScope := "https://storage.azure.com/.default"
32+
challenge := `Bearer authorization_uri="https://login.microsoftonline.com/{tenant}", resource_id="{storageResource}"`
33+
diskResource := "https://disk.azure.com/"
34+
diskScope := "https://disk.azure.com//.default"
35+
36+
for _, test := range []struct {
37+
expectedScope, format, resource string
38+
}{
39+
{format: challenge, resource: storageResource, expectedScope: storageScope},
40+
{format: challenge, resource: diskResource, expectedScope: diskScope},
41+
} {
42+
t.Run("", func(t *testing.T) {
43+
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
44+
defer close()
45+
srv.AppendResponse(
46+
mock.WithHeader("WWW-Authenticate", strings.ReplaceAll(test.format, "{storageResource}", test.resource)),
47+
mock.WithStatusCode(401),
48+
)
49+
srv.AppendResponse(mock.WithPredicate(func(r *http.Request) bool {
50+
if authz := r.Header.Values("Authorization"); len(authz) != 1 || authz[0] != "Bearer "+accessToken {
51+
t.Errorf(`unexpected Authorization "%s"`, authz)
52+
}
53+
return true
54+
}))
55+
srv.AppendResponse()
56+
authenticated := false
57+
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
58+
authenticated = true
59+
require.Equal(t, []string{test.expectedScope}, tro.Scopes)
60+
return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil
61+
})
62+
p := NewStorageChallengePolicy(cred)
63+
pl := runtime.NewPipeline("", "",
64+
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
65+
&policy.ClientOptions{Transport: srv},
66+
)
67+
req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost")
68+
require.NoError(t, err)
69+
_, err = pl.Do(req)
70+
require.NoError(t, err)
71+
require.True(t, authenticated, "policy should have authenticated")
72+
})
73+
}
74+
}
75+
76+
func TestParseTenant(t *testing.T) {
77+
actual := parseTenant("")
78+
require.Empty(t, actual)
79+
80+
expected := "00000000-0000-0000-0000-000000000000"
81+
sampleURL := "https://login.microsoftonline.com/" + expected
82+
actual = parseTenant(sampleURL)
83+
require.Equal(t, expected, actual, "tenant was not properly parsed")
84+
}
85+
86+
func TestParseTenantNegative(t *testing.T) {
87+
actual := parseTenant("")
88+
require.Empty(t, actual)
89+
90+
expected := ""
91+
sampleURL := "https://login.microsoftonline.com/" + expected
92+
actual = parseTenant(sampleURL)
93+
require.Equal(t, expected, actual)
94+
95+
sampleURL = ""
96+
actual = parseTenant(sampleURL)
97+
require.Equal(t, expected, actual)
98+
}

sdk/storage/azblob/pageblob/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type Client base.CompositeClient[generated.BlobClient, generated.PageBlobClient]
3737
// - cred - an Azure AD credential, typically obtained via the azidentity module
3838
// - options - client options; pass nil to accept the default values
3939
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
40-
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
40+
authPolicy := shared.NewStorageChallengePolicy(cred)
4141
conOptions := shared.GetClientOptions(options)
4242
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
4343
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
// Copyright (c) Microsoft Corporation. All rights reserved.
5+
// Licensed under the MIT License. See License.txt in the project root for license information.
6+
7+
package pageblob_test
8+
9+
// This test checks the storage challenge policy.
10+
func (s *PageBlobUnrecordedTestsSuite) TestManagedDiskOAuth() {
11+
//_require := require.New(s.T())
12+
//
13+
//// Set up for this test.
14+
//// In Azure Portal create a managed disk.
15+
//// Under Access Control (IAM), ensure the "Data Operator for managed disks" role is added.
16+
//// Under Disk Export, check Enable data access authentication mode.
17+
//// Click on Generate URL and paste that URL below as urlWithSas.
18+
//cred, err := azidentity.NewDefaultAzureCredential(nil)
19+
//_require.NoError(err)
20+
//
21+
//urlWithSas := "https://md-XXXXX.blob.core.windows.net/XXXX/XXXX?sv=2018-03-28&sr=b&si=XXXXXXX&sig=XXXXXXXXXX"
22+
//
23+
//// Create a page blob client with OAuth
24+
//blobClient, err := pageblob.NewClient(urlWithSas, cred, nil)
25+
//_require.NoError(err)
26+
//
27+
//_, err = blobClient.GetProperties(context.TODO(), nil)
28+
//_require.NoError(err)
29+
}

sdk/storage/azblob/service/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ type Client base.Client[generated.ServiceClient]
4242
// - cred - an Azure AD credential, typically obtained via the azidentity module
4343
// - options - client options; pass nil to accept the default values
4444
func NewClient(serviceURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
45-
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
45+
authPolicy := shared.NewStorageChallengePolicy(cred)
4646
conOptions := shared.GetClientOptions(options)
4747
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
4848
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
@@ -303,7 +303,7 @@ func (s *Client) NewBatchBuilder() (*BatchBuilder, error) {
303303

304304
switch cred := s.credential().(type) {
305305
case *azcore.TokenCredential:
306-
authPolicy = runtime.NewBearerTokenPolicy(*cred, []string{shared.TokenScope}, nil)
306+
authPolicy = shared.NewStorageChallengePolicy(*cred)
307307
case *SharedKeyCredential:
308308
authPolicy = exported.NewSharedKeyCredPolicy(cred)
309309
case nil:

0 commit comments

Comments
 (0)