Skip to content

Commit fb41851

Browse files
committed
update connection managers
1 parent bb8009e commit fb41851

File tree

3 files changed

+95
-103
lines changed

3 files changed

+95
-103
lines changed

dbt/adapters/sqlserver/sql_server_connection_manager.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,21 @@
22

33
import pyodbc
44
from azure.core.credentials import AccessToken
5-
from azure.identity import (
6-
ClientSecretCredential,
7-
ManagedIdentityCredential,
8-
)
9-
from dbt.contracts.connection import Connection, ConnectionState
10-
from dbt.events import AdapterLogger
11-
12-
from dbt.adapters.sqlserver import __version__
13-
from dbt.adapters.sqlserver.sql_server_credentials import SQLServerCredentials
14-
5+
from azure.identity import ClientSecretCredential, ManagedIdentityCredential
156
from dbt.adapters.fabric import FabricConnectionManager
16-
from dbt.adapters.fabric.fabric_connection_manager import AZURE_AUTH_FUNCTIONS as AZURE_AUTH_FUNCTIONS_FABRIC
17-
7+
from dbt.adapters.fabric.fabric_connection_manager import (
8+
AZURE_AUTH_FUNCTIONS as AZURE_AUTH_FUNCTIONS_FABRIC,
9+
)
1810
from dbt.adapters.fabric.fabric_connection_manager import (
1911
AZURE_CREDENTIAL_SCOPE,
2012
bool_to_connection_string_arg,
2113
get_pyodbc_attrs_before,
2214
)
15+
from dbt.contracts.connection import Connection, ConnectionState
16+
from dbt.events import AdapterLogger
2317

18+
from dbt.adapters.sqlserver import __version__
19+
from dbt.adapters.sqlserver.sql_server_credentials import SQLServerCredentials
2420

2521
AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials], AccessToken]
2622

@@ -82,9 +78,9 @@ def open(cls, connection: Connection) -> Connection:
8278
return connection
8379

8480
credentials = cls.get_credentials(connection.credentials)
85-
if credentials.authentication != 'sql':
81+
if credentials.authentication != "sql":
8682
return super().open(connection)
87-
83+
8884
# sql login authentication
8985

9086
con_str = [f"DRIVER={{{credentials.driver}}}"]
@@ -158,4 +154,3 @@ def connect():
158154
retry_limit=credentials.retries,
159155
retryable_exceptions=retryable_exceptions,
160156
)
161-

dbt/adapters/sqlserver/sql_server_credentials.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33

44
from dbt.adapters.fabric import FabricCredentials
55

6+
67
@dataclass
78
class SQLServerCredentials(FabricCredentials):
89
port: Optional[int] = 1433
10+
authentication: Optional[str] = "sql"
911

1012
@property
1113
def type(self):
1214
return "sqlserver"
1315

1416
def _connection_keys(self):
15-
return super()._connection_keys() +("port",)
17+
return super()._connection_keys() + ("port",)
Lines changed: 82 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
1-
import datetime as dt
2-
import json
3-
from unittest import mock
4-
51
import pytest
62
from azure.identity import AzureCliCredential
73

8-
from dbt.adapters.sqlserver.sql_server_connection_manager import (
4+
from dbt.adapters.sqlserver.sql_server_connection_manager import ( # byte_array_to_datetime,
95
bool_to_connection_string_arg,
10-
byte_array_to_datetime,
116
get_pyodbc_attrs_before,
127
)
138
from dbt.adapters.sqlserver.sql_server_credentials import SQLServerCredentials
@@ -28,22 +23,22 @@ def credentials() -> SQLServerCredentials:
2823
return credentials
2924

3025

31-
@pytest.fixture
32-
def mock_cli_access_token() -> str:
33-
access_token = "access token"
34-
expected_expires_on = 1602015811
35-
successful_output = json.dumps(
36-
{
37-
"expiresOn": dt.datetime.fromtimestamp(expected_expires_on).strftime(
38-
"%Y-%m-%d %H:%M:%S.%f"
39-
),
40-
"accessToken": access_token,
41-
"subscription": "some-guid",
42-
"tenant": "some-guid",
43-
"tokenType": "Bearer",
44-
}
45-
)
46-
return successful_output
26+
# @pytest.fixture
27+
# def mock_cli_access_token() -> str:
28+
# access_token = "access token"
29+
# expected_expires_on = 1602015811
30+
# successful_output = json.dumps(
31+
# {
32+
# "expiresOn": dt.datetime.fromtimestamp(expected_expires_on).strftime(
33+
# "%Y-%m-%d %H:%M:%S.%f"
34+
# ),
35+
# "accessToken": access_token,
36+
# "subscription": "some-guid",
37+
# "tenant": "some-guid",
38+
# "tokenType": "Bearer",
39+
# }
40+
# )
41+
# return successful_output
4742

4843

4944
def test_get_pyodbc_attrs_before_empty_dict_when_service_principal(
@@ -56,20 +51,20 @@ def test_get_pyodbc_attrs_before_empty_dict_when_service_principal(
5651
assert attrs_before == {}
5752

5853

59-
@pytest.mark.parametrize("authentication", ["CLI", "cli", "cLi"])
60-
def test_get_pyodbc_attrs_before_contains_access_token_key_for_cli_authentication(
61-
credentials: SQLServerCredentials,
62-
authentication: str,
63-
mock_cli_access_token: str,
64-
) -> None:
65-
"""
66-
When the cli authentication is used, the attrs before should contain an
67-
access token key.
68-
"""
69-
credentials.authentication = authentication
70-
with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=mock_cli_access_token)):
71-
attrs_before = get_pyodbc_attrs_before(credentials)
72-
assert 1256 in attrs_before.keys()
54+
# @pytest.mark.parametrize("authentication", ["CLI", "cli", "cLi"])
55+
# def test_get_pyodbc_attrs_before_contains_access_token_key_for_cli_authentication(
56+
# credentials: SQLServerCredentials,
57+
# authentication: str,
58+
# mock_cli_access_token: str,
59+
# ) -> None:
60+
# """
61+
# When the cli authentication is used, the attrs before should contain an
62+
# access token key.
63+
# """
64+
# credentials.authentication = authentication
65+
# with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=mock_cli_access_token)):
66+
# attrs_before = get_pyodbc_attrs_before(credentials)
67+
# assert 1256 in attrs_before.keys()
7368

7469

7570
@pytest.mark.parametrize(
@@ -79,54 +74,54 @@ def test_bool_to_connection_string_arg(key: str, value: bool, expected: str) ->
7974
assert bool_to_connection_string_arg(key, value) == expected
8075

8176

82-
@pytest.mark.parametrize(
83-
"value, expected_datetime, expected_str",
84-
[
85-
(
86-
bytes(
87-
[
88-
0xE6,
89-
0x07, # 2022 year unsigned short
90-
0x0C,
91-
0x00, # 12 month unsigned short
92-
0x11,
93-
0x00, # 17 day unsigned short
94-
0x11,
95-
0x00, # 17 hour unsigned short
96-
0x34,
97-
0x00, # 52 minute unsigned short
98-
0x12,
99-
0x00, # 18 second unsigned short
100-
0xBC,
101-
0xCC,
102-
0x5B,
103-
0x07, # 123456700 10⁻⁷ second unsigned long
104-
0xFE,
105-
0xFF, # -2 offset hour signed short
106-
0xE2,
107-
0xFF, # -30 offset minute signed short
108-
]
109-
),
110-
dt.datetime(
111-
year=2022,
112-
month=12,
113-
day=17,
114-
hour=17,
115-
minute=52,
116-
second=18,
117-
microsecond=123456700 // 1000, # 10⁻⁶ second
118-
tzinfo=dt.timezone(dt.timedelta(hours=-2, minutes=-30)),
119-
),
120-
"2022-12-17 17:52:18.123456-02:30",
121-
)
122-
],
123-
)
124-
def test_byte_array_to_datetime(
125-
value: bytes, expected_datetime: dt.datetime, expected_str: str
126-
) -> None:
127-
"""
128-
Assert SQL_SS_TIMESTAMPOFFSET_STRUCT bytes are converted to datetime and str
129-
https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date-time/data-type-support-for-odbc-date-and-time-improvements#sql_ss_timestampoffset_struct
130-
"""
131-
assert byte_array_to_datetime(value) == expected_datetime
132-
assert str(byte_array_to_datetime(value)) == expected_str
77+
# @pytest.mark.parametrize(
78+
# "value, expected_datetime, expected_str",
79+
# [
80+
# (
81+
# bytes(
82+
# [
83+
# 0xE6,
84+
# 0x07, # 2022 year unsigned short
85+
# 0x0C,
86+
# 0x00, # 12 month unsigned short
87+
# 0x11,
88+
# 0x00, # 17 day unsigned short
89+
# 0x11,
90+
# 0x00, # 17 hour unsigned short
91+
# 0x34,
92+
# 0x00, # 52 minute unsigned short
93+
# 0x12,
94+
# 0x00, # 18 second unsigned short
95+
# 0xBC,
96+
# 0xCC,
97+
# 0x5B,
98+
# 0x07, # 123456700 10⁻⁷ second unsigned long
99+
# 0xFE,
100+
# 0xFF, # -2 offset hour signed short
101+
# 0xE2,
102+
# 0xFF, # -30 offset minute signed short
103+
# ]
104+
# ),
105+
# dt.datetime(
106+
# year=2022,
107+
# month=12,
108+
# day=17,
109+
# hour=17,
110+
# minute=52,
111+
# second=18,
112+
# microsecond=123456700 // 1000, # 10⁻⁶ second
113+
# tzinfo=dt.timezone(dt.timedelta(hours=-2, minutes=-30)),
114+
# ),
115+
# "2022-12-17 17:52:18.123456-02:30",
116+
# )
117+
# ],
118+
# )
119+
# def test_byte_array_to_datetime(
120+
# value: bytes, expected_datetime: dt.datetime, expected_str: str
121+
# ) -> None:
122+
# """
123+
# Assert SQL_SS_TIMESTAMPOFFSET_STRUCT bytes are converted to datetime and str
124+
# https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date-time/data-type-support-for-odbc-date-and-time-improvements#sql_ss_timestampoffset_struct
125+
# """
126+
# assert byte_array_to_datetime(value) == expected_datetime
127+
# assert str(byte_array_to_datetime(value)) == expected_str

0 commit comments

Comments
 (0)