Skip to content

Commit 1b9c0ee

Browse files
Reapply "update AzureBaseHook to return credentials that supports get_token method" (apache#56228)
* all changes and fixed * more fix --------- Co-authored-by: Karun Poudel <64540927+karunpoudel-chr@users.noreply.github.com>
1 parent dcd9f8b commit 1b9c0ee

File tree

4 files changed

+186
-17
lines changed

4 files changed

+186
-17
lines changed

docs/spelling_wordlist.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ classpaths
272272
cleartext
273273
cli
274274
clientId
275+
ClientSecretCredential
275276
cloudant
276277
CloudantV
277278
cloudbuild
@@ -476,6 +477,7 @@ deduplicated
476477
deduplication
477478
deepcopy
478479
deepcopying
480+
DefaultAzureCredential
479481
deferrable
480482
deidentify
481483
DeidentifyContentResponse
@@ -1612,6 +1614,7 @@ serializer
16121614
serializers
16131615
serverless
16141616
ServiceAccount
1617+
ServicePrincipalCredentials
16151618
ServiceResource
16161619
SES
16171620
sessionmaker

providers/microsoft/azure/docs/connections/azure.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ Extra (optional)
7474
It specifies the json that contains the authentication information.
7575
* ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_.
7676
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.
77+
* ``use_azure_identity_object``: If set to true, it will use credential of newer type: ClientSecretCredential or DefaultAzureCredential instead of ServicePrincipalCredentials or AzureIdentityCredentialAdapter.
78+
These newer credentials support get_token method which can be used to generate OAuth token with custom scope.
7779

7880
The entire extra column can be left out to fall back on DefaultAzureCredential_.
7981

providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,25 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from typing import Any
19+
from typing import TYPE_CHECKING, Any
2020

2121
from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
2222
from azure.common.credentials import ServicePrincipalCredentials
23+
from azure.identity import ClientSecretCredential, DefaultAzureCredential
2324

2425
from airflow.exceptions import AirflowException
2526
from airflow.providers.microsoft.azure.utils import (
2627
AzureIdentityCredentialAdapter,
2728
add_managed_identity_connection_widgets,
29+
get_sync_default_azure_credential,
2830
)
2931
from airflow.providers.microsoft.azure.version_compat import BaseHook
3032

33+
if TYPE_CHECKING:
34+
from azure.core.credentials import AccessToken
35+
36+
from airflow.sdk import Connection
37+
3138

3239
class AzureBaseHook(BaseHook):
3340
"""
@@ -85,7 +92,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
8592
},
8693
}
8794

88-
def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
95+
def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
8996
self.sdk_client = sdk_client
9097
self.conn_id = conn_id
9198
super().__init__()
@@ -96,8 +103,9 @@ def get_conn(self) -> Any:
96103
97104
:return: the authenticated client.
98105
"""
106+
if not self.sdk_client:
107+
raise ValueError("`sdk_client` must be provided to AzureBaseHook to use `get_conn` method.")
99108
conn = self.get_connection(self.conn_id)
100-
tenant = conn.extra_dejson.get("tenantId")
101109
subscription_id = conn.extra_dejson.get("subscriptionId")
102110
key_path = conn.extra_dejson.get("key_path")
103111
if key_path:
@@ -111,22 +119,90 @@ def get_conn(self) -> Any:
111119
self.log.info("Getting connection using a JSON config.")
112120
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
113121

114-
credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
122+
credentials = self.get_credential(conn=conn)
123+
124+
return self.sdk_client(
125+
credentials=credentials,
126+
subscription_id=subscription_id,
127+
)
128+
129+
def get_credential(
130+
self, *, conn: Connection | None = None
131+
) -> (
132+
ServicePrincipalCredentials
133+
| AzureIdentityCredentialAdapter
134+
| ClientSecretCredential
135+
| DefaultAzureCredential
136+
):
137+
"""
138+
Get Azure credential object for the connection.
139+
140+
Azure Identity based credential object (``ClientSecretCredential``, ``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` method.
141+
Older Credential objects (``ServicePrincipalCredentials``, ``AzureIdentityCredentialAdapter``) are supported for backward compatibility.
142+
143+
:return: The Azure credential object
144+
"""
145+
if not conn:
146+
conn = self.get_connection(self.conn_id)
147+
tenant = conn.extra_dejson.get("tenantId")
148+
credential: (
149+
ServicePrincipalCredentials
150+
| AzureIdentityCredentialAdapter
151+
| ClientSecretCredential
152+
| DefaultAzureCredential
153+
)
115154
if all([conn.login, conn.password, tenant]):
116-
self.log.info("Getting connection using specific credentials and subscription_id.")
117-
credentials = ServicePrincipalCredentials(
118-
client_id=conn.login, secret=conn.password, tenant=tenant
119-
)
155+
credential = self._get_client_secret_credential(conn)
120156
else:
121-
self.log.info("Using DefaultAzureCredential as credential")
122-
managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
123-
workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
124-
credentials = AzureIdentityCredentialAdapter(
157+
credential = self._get_default_azure_credential(conn)
158+
return credential
159+
160+
def _get_client_secret_credential(
161+
self, conn: Connection
162+
) -> ServicePrincipalCredentials | ClientSecretCredential:
163+
self.log.info("Getting credentials using specific credentials and subscription_id.")
164+
extra_dejson = conn.extra_dejson
165+
tenant = extra_dejson.get("tenantId")
166+
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
167+
if use_azure_identity_object:
168+
return ClientSecretCredential(
169+
client_id=conn.login, # type: ignore[arg-type]
170+
client_secret=conn.password, # type: ignore[arg-type]
171+
tenant_id=tenant, # type: ignore[arg-type]
172+
)
173+
return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant)
174+
175+
def _get_default_azure_credential(
176+
self, conn: Connection
177+
) -> DefaultAzureCredential | AzureIdentityCredentialAdapter:
178+
self.log.info("Using DefaultAzureCredential as credential")
179+
extra_dejson = conn.extra_dejson
180+
managed_identity_client_id = extra_dejson.get("managed_identity_client_id")
181+
workload_identity_tenant_id = extra_dejson.get("workload_identity_tenant_id")
182+
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
183+
if use_azure_identity_object:
184+
return get_sync_default_azure_credential(
125185
managed_identity_client_id=managed_identity_client_id,
126186
workload_identity_tenant_id=workload_identity_tenant_id,
127187
)
128-
129-
return self.sdk_client(
130-
credentials=credentials,
131-
subscription_id=subscription_id,
188+
return AzureIdentityCredentialAdapter(
189+
managed_identity_client_id=managed_identity_client_id,
190+
workload_identity_tenant_id=workload_identity_tenant_id,
132191
)
192+
193+
def get_token(self, *scopes, **kwargs) -> AccessToken:
194+
"""
195+
Request an access token for `scopes`.
196+
197+
To use this method, set `use_azure_identity_object: True` in the connection extra field.
198+
ServicePrincipalCredentials and AzureIdentityCredentialAdapter don't support `get_token` method.
199+
"""
200+
credential = self.get_credential()
201+
if isinstance(credential, ServicePrincipalCredentials) or isinstance(
202+
credential, AzureIdentityCredentialAdapter
203+
):
204+
raise AttributeError(
205+
"ServicePrincipalCredentials and AzureIdentityCredentialAdapter don't support get_token method. "
206+
"Please set `use_azure_identity_object: True` in the connection extra field to use credential that support get_token method."
207+
)
208+
return credential.get_token(*scopes, **kwargs)

providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Connection = MagicMock() # type: ignore[misc]
3232

3333
MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
34+
UTILS = "airflow.providers.microsoft.azure.utils"
3435

3536

3637
class TestBaseAzureHook:
@@ -111,7 +112,7 @@ def test_get_conn_with_credentials(self, mock_spc, mocked_connection):
111112
indirect=True,
112113
)
113114
@patch("azure.common.credentials.ServicePrincipalCredentials")
114-
@patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter")
115+
@patch(f"{MODULE}.AzureIdentityCredentialAdapter")
115116
def test_get_conn_fallback_to_azure_identity_credential_adapter(
116117
self,
117118
mock_credential_adapter,
@@ -133,3 +134,90 @@ def test_get_conn_fallback_to_azure_identity_credential_adapter(
133134
credentials=mock_credential,
134135
subscription_id="subscription_id",
135136
)
137+
138+
@patch(f"{MODULE}.ClientSecretCredential")
139+
@pytest.mark.parametrize(
140+
"mocked_connection",
141+
[
142+
Connection(
143+
conn_id="azure_default",
144+
login="my_login",
145+
password="my_password",
146+
extra={"tenantId": "my_tenant", "use_azure_identity_object": True},
147+
),
148+
],
149+
indirect=True,
150+
)
151+
def test_get_credential_with_client_secret(self, mock_spc, mocked_connection):
152+
mock_spc.return_value = "foo-bar"
153+
cred = AzureBaseHook().get_credential()
154+
155+
mock_spc.assert_called_once_with(
156+
client_id=mocked_connection.login,
157+
client_secret=mocked_connection.password,
158+
tenant_id=mocked_connection.extra_dejson["tenantId"],
159+
)
160+
assert cred == "foo-bar"
161+
162+
@patch(f"{UTILS}.DefaultAzureCredential")
163+
@pytest.mark.parametrize(
164+
"mocked_connection",
165+
[
166+
Connection(
167+
conn_id="azure_default",
168+
extra={"use_azure_identity_object": True},
169+
),
170+
],
171+
indirect=True,
172+
)
173+
def test_get_credential_with_azure_default_credential(self, mock_spc, mocked_connection):
174+
mock_spc.return_value = "foo-bar"
175+
cred = AzureBaseHook().get_credential()
176+
177+
mock_spc.assert_called_once_with()
178+
assert cred == "foo-bar"
179+
180+
@patch(f"{UTILS}.DefaultAzureCredential")
181+
@pytest.mark.parametrize(
182+
"mocked_connection",
183+
[
184+
Connection(
185+
conn_id="azure_default",
186+
extra={
187+
"managed_identity_client_id": "test_client_id",
188+
"workload_identity_tenant_id": "test_tenant_id",
189+
"use_azure_identity_object": True,
190+
},
191+
),
192+
],
193+
indirect=True,
194+
)
195+
def test_get_credential_with_azure_default_credential_with_extra(self, mock_spc, mocked_connection):
196+
mock_spc.return_value = "foo-bar"
197+
cred = AzureBaseHook().get_credential()
198+
199+
mock_spc.assert_called_once_with(
200+
managed_identity_client_id=mocked_connection.extra_dejson.get("managed_identity_client_id"),
201+
workload_identity_tenant_id=mocked_connection.extra_dejson.get("workload_identity_tenant_id"),
202+
additionally_allowed_tenants=[mocked_connection.extra_dejson.get("workload_identity_tenant_id")],
203+
)
204+
assert cred == "foo-bar"
205+
206+
@patch(f"{UTILS}.DefaultAzureCredential")
207+
@pytest.mark.parametrize(
208+
"mocked_connection",
209+
[
210+
Connection(
211+
conn_id="azure_default",
212+
extra={"use_azure_identity_object": True},
213+
),
214+
],
215+
indirect=True,
216+
)
217+
def test_get_token_with_azure_default_credential(self, mock_spc, mocked_connection):
218+
mock_spc.return_value.get_token.return_value = "new-token"
219+
scope = "custom_scope"
220+
token = AzureBaseHook().get_token(scope)
221+
222+
mock_spc.assert_called_once_with()
223+
assert token == "new-token"

0 commit comments

Comments
 (0)