Commit d8b3300

Azure IAM/Entra ID support for SnowflakeHook (apache#55874)
Co-authored-by: Karun Poudel <64540927+karunpoudel-chr@users.noreply.github.com>
1 parent 9ef4f12 commit d8b3300

8 files changed: +191 -15 lines changed
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements. See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership. The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License. You may obtain a copy of the License at
+
+..   http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied. See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+.. include:: /../../../devel-common/src/sphinx_exts/includes/providers-configurations-ref.rst
+.. include:: /../../../devel-common/src/sphinx_exts/includes/sections-and-options.rst

providers/snowflake/docs/connections/snowflake.rst

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ Extra (optional)
 * ``token_endpoint``: Specify token endpoint for external OAuth provider.
 * ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``.
 * ``refresh_token``: Specify refresh_token for OAuth connection.
+* ``azure_conn_id``: Azure Connection ID to be used for retrieving the OAuth token using Azure Entra authentication. Login and Password fields aren't required when using this method. Scope for the Azure OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``. Requires ``apache-airflow-providers-microsoft-azure>=12.8.0``.
 * ``private_key_file``: Specify the path to the private key file.
 * ``private_key_content``: Specify the content of the private key file in base64 encoded format. You can use the following Python code to encode the private key:
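As a rough illustration of the new extra, the sketch below builds the JSON extras for a Snowflake connection that takes the Azure path. The key names and the ``oauth`` authenticator value come from this change and its tests; the connection id ``azure_default`` and the account, region, warehouse and database values are hypothetical placeholders.

```python
import json

# Hypothetical values for illustration only; replace with real identifiers.
snowflake_extra = {
    "account": "my_account",
    "region": "my_region",
    "warehouse": "my_wh",
    "database": "my_db",
    # The hook only looks at azure_conn_id when the authenticator is "oauth".
    "authenticator": "oauth",
    # Id of an existing Azure connection (hypothetical name).
    "azure_conn_id": "azure_default",
}


def build_extra_json(extra: dict) -> str:
    """Serialize connection extras to the JSON string stored in the Extra field."""
    return json.dumps(extra, sort_keys=True)


extra_json = build_extra_json(snowflake_extra)
print(extra_json)
```

Login and Password are left out on purpose, since the documentation line above states they are not required for this method.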

providers/snowflake/docs/index.rst

Lines changed: 9 additions & 7 deletions
@@ -43,6 +43,7 @@
     :maxdepth: 1
     :caption: References

+    Configuration <configurations-ref>
     Python API <_api/airflow/providers/snowflake/index>

 .. toctree::
@@ -127,13 +128,14 @@ You can install such cross-provider dependencies when installing from PyPI. For
     pip install apache-airflow-providers-snowflake[common.compat]


-================================================================================================================== =================
-Dependent package                                                                                                  Extra
-================================================================================================================== =================
-`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
-`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_       ``common.sql``
-`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_     ``openlineage``
-================================================================================================================== =================
+====================================================================================================================== =================
+Dependent package                                                                                                      Extra
+====================================================================================================================== =================
+`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_     ``common.compat``
+`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_           ``common.sql``
+`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_         ``openlineage``
+`apache-airflow-providers-microsoft-azure <https://airflow.apache.org/docs/apache-airflow-providers-microsoft-azure>`_ ``microsoft.azure``
+====================================================================================================================== =================

 Downloading official packages
 -----------------------------

providers/snowflake/provider.yaml

Lines changed: 13 additions & 0 deletions
@@ -143,3 +143,16 @@ triggers:
   - integration-name: Snowflake
     python-modules:
       - airflow.providers.snowflake.triggers.snowflake_trigger
+
+config:
+  snowflake:
+    description: |
+      Configuration for Snowflake hooks and operators.
+    options:
+      azure_oauth_scope:
+        description: |
+          The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication.
+        version_added: 6.6.0
+        type: string
+        example: ~
+        default: "api://snowflake_oauth_server/.default"
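Airflow config options can generally be overridden with environment variables named ``AIRFLOW__{SECTION}__{KEY}``. Assuming that convention applies to this new option as it does to other provider options, the sketch below derives the variable name for ``[snowflake] azure_oauth_scope``. The helper name is hypothetical and not part of Airflow.

```python
def airflow_config_env_var(section: str, key: str) -> str:
    """Return the environment variable name Airflow reads for a config option.

    Hypothetical helper; it only encodes the AIRFLOW__{SECTION}__{KEY} naming rule.
    """
    return f"AIRFLOW__{section.upper()}__{key.upper()}"


env_var_name = airflow_config_env_var("snowflake", "azure_oauth_scope")
print(env_var_name)
```

When neither the variable nor an ``airflow.cfg`` entry is present, the hook in this commit falls back to ``api://snowflake_oauth_server/.default``.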

providers/snowflake/pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -75,6 +75,9 @@ dependencies = [
 # The optional dependencies should be modified in place in the generated file
 # Any change in the dependencies is preserved when the file is regenerated
 [project.optional-dependencies]
+"microsoft.azure" = [
+    "apache-airflow-providers-microsoft-azure"
+]
 "openlineage" = [
     "apache-airflow-providers-openlineage>=2.3.0"
 ]
@@ -86,6 +89,7 @@ dev = [
     "apache-airflow-devel-common",
     "apache-airflow-providers-common-compat",
     "apache-airflow-providers-common-sql",
+    "apache-airflow-providers-microsoft-azure",
     "apache-airflow-providers-openlineage",
     # Additional devel dependencies (do not remove this line and add extra development dependencies)
     "responses>=0.25.0",

providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py

Lines changed: 14 additions & 0 deletions
@@ -94,4 +94,18 @@ def get_provider_info():
                 "python-modules": ["airflow.providers.snowflake.triggers.snowflake_trigger"],
             }
         ],
+        "config": {
+            "snowflake": {
+                "description": "Configuration for Snowflake hooks and operators.\n",
+                "options": {
+                    "azure_oauth_scope": {
+                        "description": "The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication.\n",
+                        "version_added": "6.6.0",
+                        "type": "string",
+                        "example": None,
+                        "default": "api://snowflake_oauth_server/.default",
+                    }
+                },
+            }
+        },
     }

providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py

Lines changed: 52 additions & 8 deletions
@@ -36,12 +36,18 @@
 from snowflake.sqlalchemy import URL
 from sqlalchemy import create_engine

+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.handlers import return_single_query_results
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
 from airflow.utils.strings import to_boolean

+try:
+    from airflow.sdk import Connection
+except ImportError:
+    from airflow.models.connection import Connection  # type: ignore[assignment]
+
 T = TypeVar("T")
 if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors import OperatorLineage
@@ -94,6 +100,7 @@ class SnowflakeHook(DbApiHook):
     hook_name = "Snowflake"
     supports_autocommit = True
     _test_connection_sql = "select 1"
+    default_azure_oauth_scope = "api://snowflake_oauth_server/.default"

     @classmethod
     def get_connection_form_widgets(cls) -> dict[str, Any]:
@@ -246,6 +253,40 @@ def get_oauth_token(
         token = response.json()["access_token"]
         return token

+    def get_azure_oauth_token(self, azure_conn_id: str) -> str:
+        """
+        Generate OAuth access token using Azure connection id.
+
+        This uses AzureBaseHook on the connection id to retrieve the token. Scope for the OAuth token can be
+        set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``.
+
+        :param azure_conn_id: The connection id for the Azure connection that will be used to fetch the token.
+        :raises AttributeError: If AzureBaseHook does not have a get_token method which happens when
+            package apache-airflow-providers-microsoft-azure<12.8.0.
+        :returns: The OAuth access token string.
+        """
+        if TYPE_CHECKING:
+            from airflow.providers.microsoft.azure.hooks.azure_base import AzureBaseHook
+
+        try:
+            azure_conn = Connection.get(azure_conn_id)
+        except AttributeError:
+            azure_conn = Connection.get_connection_from_secrets(azure_conn_id)  # type: ignore[attr-defined]
+        azure_base_hook: AzureBaseHook = azure_conn.get_hook()
+        scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
+        try:
+            token = azure_base_hook.get_token(scope).token
+        except AttributeError as e:
+            if e.name == "get_token" and e.obj == azure_base_hook:
+                raise AttributeError(
+                    "'AzureBaseHook' object has no attribute 'get_token'. "
+                    "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0",
+                    name=e.name,
+                    obj=e.obj,
+                ) from e
+            raise
+        return token
+
     @cached_property
     def _get_conn_params(self) -> dict[str, str | None]:
         """
@@ -349,14 +390,17 @@ def _get_conn_params(self) -> dict[str, str | None]:
             conn_config["authenticator"] = "oauth"

         if conn_config.get("authenticator") == "oauth":
-            token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
-            conn_config["client_id"] = conn.login
-            conn_config["client_secret"] = conn.password
-            conn_config["token"] = self.get_oauth_token(
-                conn_config=conn_config,
-                token_endpoint=token_endpoint,
-                grant_type=extra_dict.get("grant_type", "refresh_token"),
-            )
+            if extra_dict.get("azure_conn_id"):
+                conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"])
+            else:
+                token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
+                conn_config["client_id"] = conn.login
+                conn_config["client_secret"] = conn.password
+                conn_config["token"] = self.get_oauth_token(
+                    conn_config=conn_config,
+                    token_endpoint=token_endpoint,
+                    grant_type=extra_dict.get("grant_type", "refresh_token"),
+                )

         conn_config.pop("login", None)
         conn_config.pop("user", None)
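The branch added to ``_get_conn_params`` gives ``azure_conn_id`` precedence over the existing OAuth flow whenever the authenticator is ``oauth``. A simplified sketch of that selection logic, with the two token sources passed in as plain callables, could look like this. ``resolve_oauth_token`` and both callables are hypothetical and only mirror the shape of the hunk above.

```python
from typing import Callable, Mapping


def resolve_oauth_token(
    extra: Mapping[str, str],
    azure_token_fn: Callable[[str], str],
    generic_token_fn: Callable[[str], str],
) -> str:
    """Pick the token source the way the hunk above does."""
    azure_conn_id = extra.get("azure_conn_id")
    if azure_conn_id:
        # A non-empty azure_conn_id wins; client id and secret are not consulted.
        return azure_token_fn(azure_conn_id)
    # Otherwise fall back to the pre-existing flow with its default grant type.
    return generic_token_fn(extra.get("grant_type", "refresh_token"))


azure_choice = resolve_oauth_token(
    {"azure_conn_id": "azure_default"},
    azure_token_fn=lambda conn_id: f"azure:{conn_id}",
    generic_token_fn=lambda grant: f"generic:{grant}",
)
generic_choice = resolve_oauth_token(
    {},
    azure_token_fn=lambda conn_id: f"azure:{conn_id}",
    generic_token_fn=lambda grant: f"generic:{grant}",
)
print(azure_choice, generic_choice)
```

Because the check is a truthiness test, an empty ``azure_conn_id`` string behaves the same as a missing key.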

providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py

Lines changed: 79 additions & 0 deletions
@@ -666,6 +666,44 @@ def test_get_conn_params_should_support_oauth_with_client_credentials(
         assert "region" in conn_params_extra_keys
         assert "account" in conn_params_extra_keys

+    def test_get_conn_params_should_support_oauth_with_azure_conn_id(self, mocker):
+        azure_conn_id = "azure_test_conn"
+        mock_azure_token = "azure_test_token"
+        connection_kwargs = {
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "region": "af_region",
+                "warehouse": "af_wh",
+                "authenticator": "oauth",
+                "azure_conn_id": azure_conn_id,
+            },
+        }
+
+        mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection")
+        mock_azure_base_hook = mock_connection_class.get.return_value.get_hook.return_value
+        mock_azure_base_hook.get_token.return_value.token = mock_azure_token
+
+        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params
+
+        # Check AzureBaseHook initialization and get_token call args
+        mock_connection_class.get.assert_called_once_with(azure_conn_id)
+        mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope)
+
+        assert "authenticator" in conn_params
+        assert conn_params["authenticator"] == "oauth"
+        assert "token" in conn_params
+        assert conn_params["token"] == mock_azure_token
+
+        assert "user" not in conn_params
+        assert "password" not in conn_params
+        assert "refresh_token" not in conn_params
+        # Mandatory fields to generate account_identifier `https://<account>.<region>`
+        assert "region" in conn_params
+        assert "account" in conn_params
+
     def test_should_add_partner_info(self):
         with mock.patch.dict(
             "os.environ",
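The test above configures the whole call chain ``Connection.get(...).get_hook().get_token(...).token`` through nested ``return_value`` attributes on a single patched class. A minimal standalone sketch of that mocking technique, using only ``unittest.mock`` and no Airflow imports, might be the following. The scope string ``demo-scope`` is a hypothetical placeholder.

```python
from unittest import mock

# Stand-in for the patched Connection class.
connection_class = mock.MagicMock(name="Connection")

# Each .return_value is the object produced by calling the attribute before it.
azure_hook = connection_class.get.return_value.get_hook.return_value
azure_hook.get_token.return_value.token = "azure_test_token"

# Code under test would walk the same chain.
token = connection_class.get("azure_test_conn").get_hook().get_token("demo-scope").token

# The recorded calls can then be asserted on individual links of the chain.
connection_class.get.assert_called_once_with("azure_test_conn")
azure_hook.get_token.assert_called_once_with("demo-scope")
print(token)
```

Holding a reference to the intermediate mock (``azure_hook`` here) keeps the later call assertions short.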
@@ -1054,3 +1092,44 @@ def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, requests_pos
             headers={"Content-Type": "application/x-www-form-urlencoded"},
             auth=basic_auth,
         )
+
+    def test_get_azure_oauth_token(self, mocker):
+        """Test get_azure_oauth_token method gets token from provided connection id"""
+        azure_conn_id = "azure_test_conn"
+        mock_azure_token = "azure_test_token"
+
+        mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection")
+        mock_azure_base_hook = mock_connection_class.get.return_value.get_hook.return_value
+        mock_azure_base_hook.get_token.return_value.token = mock_azure_token
+
+        hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
+        token = hook.get_azure_oauth_token(azure_conn_id)
+
+        # Check AzureBaseHook initialization and get_token call args
+        mock_connection_class.get.assert_called_once_with(azure_conn_id)
+        mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope)
+        assert token == mock_azure_token
+
+    def test_get_azure_oauth_token_expect_failure_on_get_token(self, mocker):
+        """Test get_azure_oauth_token raises AttributeError when the Azure hook has no get_token method"""
+
+        class MockAzureBaseHookWithoutGetToken:
+            def __init__(self):
+                pass
+
+        azure_conn_id = "azure_test_conn"
+        mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection")
+        mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken()
+
+        hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
+        with pytest.raises(
+            AttributeError,
+            match=(
+                "'AzureBaseHook' object has no attribute 'get_token'. "
+                "Please upgrade apache-airflow-providers-microsoft-azure>="
+            ),
+        ):
+            hook.get_azure_oauth_token(azure_conn_id)
+
+        # Check AzureBaseHook initialization
+        mock_connection_class.get.assert_called_once_with(azure_conn_id)
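One detail worth noting about the failure test: ``pytest.raises(match=...)`` applies the pattern with ``re.search``, so the ``.`` in the expected message acts as a regex wildcard. The sketch below reproduces that matching with the standard library only. The strings are copied from the test, and wrapping the pattern in ``re.escape`` is an optional tightening, not something the commit does.

```python
import re

message = (
    "'AzureBaseHook' object has no attribute 'get_token'. "
    "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0"
)
pattern = (
    "'AzureBaseHook' object has no attribute 'get_token'. "
    "Please upgrade apache-airflow-providers-microsoft-azure>="
)

# This is effectively what pytest.raises(match=pattern) checks.
loose_match = re.search(pattern, message) is not None

# Escaping makes "." match only a literal dot.
strict_match = re.search(re.escape(pattern), message) is not None

# A message with a different character in place of the dot still passes the loose check.
altered = message.replace("'get_token'.", "'get_token'X")
loose_on_altered = re.search(pattern, altered) is not None
strict_on_altered = re.search(re.escape(pattern), altered) is not None

print(loose_match, strict_match, loose_on_altered, strict_on_altered)
```

The loose pattern is adequate for this test because it leaves the version suffix open, which lets the minimum version change without touching the assertion.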
