Skip to content

Commit 560c032

Browse files
authored
Merge pull request #37 from Arize-ai/amunoz/azure-managed-identity
add managed identity for azure store auth
2 parents 7300baf + c740ca0 commit 560c032

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

broker/stores/azure/ad.go

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"sync"
1010
"time"
1111

12+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
1213
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
1314
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
1415
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
@@ -36,6 +37,9 @@ type adStore struct {
3637
}
3738

3839
// NewAD creates a new Azure AD authenticated Store from the provided URL.
40+
// Authentication: if AZURE_CLIENT_ID and AZURE_CLIENT_SECRET are both set, client secret is used.
41+
// Otherwise DefaultAzureCredential is used, which supports workload identity (e.g. in Kubernetes
42+
// with AZURE_CLIENT_ID, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE), managed identity, and Azure CLI.
3943
func NewAD(ep *url.URL) (stores.Store, error) {
4044
var args StoreQueryArgs
4145

@@ -53,31 +57,40 @@ func NewAD(ep *url.URL) (stores.Store, error) {
5357
var container = path[1]
5458
var prefix = strings.Join(path[2:], "/")
5559

56-
var clientID = os.Getenv("AZURE_CLIENT_ID")
57-
var clientSecret = os.Getenv("AZURE_CLIENT_SECRET")
58-
59-
if clientID == "" || clientSecret == "" {
60-
return nil, fmt.Errorf("AZURE_CLIENT_ID and AZURE_CLIENT_SECRET must be set for azure-ad:// URLs")
61-
}
62-
6360
// arize change to support china cloud
6461
blobDomain := os.Getenv("AZURE_BLOB_DOMAIN")
6562
if blobDomain == "" {
6663
blobDomain = "blob.core.windows.net"
6764
}
6865

69-
var credentials, err = azidentity.NewClientSecretCredential(
70-
tenantID,
71-
clientID,
72-
clientSecret,
73-
&azidentity.ClientSecretCredentialOptions{
66+
var credentials azcore.TokenCredential
67+
var err error
68+
var clientID = os.Getenv("AZURE_CLIENT_ID")
69+
var clientSecret = os.Getenv("AZURE_CLIENT_SECRET")
70+
if clientID != "" && clientSecret != "" {
71+
credentials, err = azidentity.NewClientSecretCredential(
72+
tenantID,
73+
clientID,
74+
clientSecret,
75+
&azidentity.ClientSecretCredentialOptions{
76+
DisableInstanceDiscovery: true,
77+
},
78+
)
79+
} else {
80+
credentials, err = azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
81+
TenantID: tenantID,
7482
DisableInstanceDiscovery: true,
75-
},
76-
)
83+
})
84+
}
7785
if err != nil {
7886
return nil, err
7987
}
8088

89+
var authMethod = "workload identity / default chain"
90+
if clientID != "" && clientSecret != "" {
91+
authMethod = "client secret"
92+
}
93+
8194
var refreshFn = func(credential azblob.TokenCredential) time.Duration {
8295
if token, err := credentials.GetToken(
8396
context.Background(),
@@ -126,6 +139,7 @@ func NewAD(ep *url.URL) (stores.Store, error) {
126139
"blobDomain": blobDomain,
127140
"container": container,
128141
"prefix": prefix,
142+
"auth": authMethod,
129143
}).Info("constructed new Azure AD storage client")
130144

131145
return store, nil

0 commit comments

Comments
 (0)