2121import pickle
2222from copy import deepcopy
2323from unittest import mock
24- from unittest .mock import MagicMock
2524
2625import pytest
2726from kubernetes .client import models as k8s
3736from airflow .utils .sqlalchemy import (
3837 ExecutorConfigType ,
3938 ensure_pod_is_valid_after_unpickling ,
39+ get_dialect_name ,
4040 is_sqlalchemy_v1 ,
4141 prohibit_commit ,
4242 with_row_locks ,
5252TEST_POD = k8s .V1Pod (spec = k8s .V1PodSpec (containers = [k8s .V1Container (name = "base" )]))
5353
5454
55+ class TestGetDialectName :
56+ def test_returns_dialect_name_when_present (self , mocker ):
57+ mock_session = mocker .Mock ()
58+ mock_bind = mocker .Mock ()
59+ mock_bind .dialect .name = "postgresql"
60+ mock_session .get_bind .return_value = mock_bind
61+
62+ assert get_dialect_name (mock_session ) == "postgresql"
63+
64+ def test_raises_when_no_bind (self , mocker ):
65+ mock_session = mocker .Mock ()
66+ mock_session .get_bind .return_value = None
67+
68+ with pytest .raises (ValueError , match = "No bind/engine is associated" ):
69+ get_dialect_name (mock_session )
70+
71+ def test_returns_none_when_dialect_has_no_name (self , mocker ):
72+ mock_session = mocker .Mock ()
73+ mock_bind = mocker .Mock ()
74+ # simulate dialect object without `name` attribute
75+ mock_bind .dialect = mock .Mock ()
76+ delattr (mock_bind .dialect , "name" ) if hasattr (mock_bind .dialect , "name" ) else None
77+ mock_session .get_bind .return_value = mock_bind
78+
79+ assert get_dialect_name (mock_session ) is None
80+
81+
5582class TestSqlAlchemyUtils :
5683 def setup_method (self ):
5784 session = Session ()
5885
5986 # make sure NOT to run in UTC. Only postgres supports storing
6087 # timezone information in the datetime field
61- if session . bind . dialect . name == "postgresql" :
88+ if get_dialect_name ( session ) == "postgresql" :
6289 session .execute (text ("SET timezone='Europe/Amsterdam'" ))
6390
6491 self .session = session
@@ -124,7 +151,7 @@ def test_process_bind_param_naive(self):
124151 dag .clear ()
125152
126153 @pytest .mark .parametrize (
127- "dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock" ,
154+ ( "dialect" , " supports_for_update_of" , " use_row_level_lock_conf" , " expected_use_row_level_lock") ,
128155 [
129156 ("postgresql" , True , True , True ),
130157 ("postgresql" , True , False , False ),
@@ -192,7 +219,7 @@ def teardown_method(self):
192219
193220class TestExecutorConfigType :
194221 @pytest .mark .parametrize (
195- "input, expected" ,
222+ ( "input" , " expected") ,
196223 [
197224 ("anything" , "anything" ),
198225 (
@@ -206,13 +233,13 @@ class TestExecutorConfigType:
206233 ),
207234 ],
208235 )
209- def test_bind_processor (self , input , expected ):
236+ def test_bind_processor (self , input , expected , mocker ):
210237 """
211238 The returned bind processor should pickle the object as is, unless it is a dictionary with
212239 a pod_override node, in which case it should run it through BaseSerialization.
213240 """
214241 config_type = ExecutorConfigType ()
215- mock_dialect = MagicMock ()
242+ mock_dialect = mocker . MagicMock ()
216243 mock_dialect .dbapi = None
217244 process = config_type .bind_processor (mock_dialect )
218245 assert pickle .loads (process (input )) == expected
@@ -239,13 +266,13 @@ def test_bind_processor(self, input, expected):
239266 ),
240267 ],
241268 )
242- def test_result_processor (self , input ):
269+ def test_result_processor (self , input , mocker ):
243270 """
244271 The returned bind processor should pickle the object as is, unless it is a dictionary with
245272 a pod_override node whose value was serialized with BaseSerialization.
246273 """
247274 config_type = ExecutorConfigType ()
248- mock_dialect = MagicMock ()
275+ mock_dialect = mocker . MagicMock ()
249276 mock_dialect .dbapi = None
250277 process = config_type .result_processor (mock_dialect , None )
251278 result = process (input )
@@ -277,7 +304,7 @@ def __eq__(self, other):
277304 assert instance .compare_values (a , a ) is False
278305 assert instance .compare_values ("a" , "a" ) is True
279306
280- def test_result_processor_bad_pickled_obj (self ):
307+ def test_result_processor_bad_pickled_obj (self , mocker ):
281308 """
282309 If unpickled obj is missing attrs that curr lib expects
283310 """
@@ -309,7 +336,7 @@ def test_result_processor_bad_pickled_obj(self):
309336
310337 # get the result processor method
311338 config_type = ExecutorConfigType ()
312- mock_dialect = MagicMock ()
339+ mock_dialect = mocker . MagicMock ()
313340 mock_dialect .dbapi = None
314341 process = config_type .result_processor (mock_dialect , None )
315342
@@ -322,13 +349,13 @@ def test_result_processor_bad_pickled_obj(self):
322349
323350
324351@pytest .mark .parametrize (
325- "mock_version, expected_result" ,
352+ ( "mock_version" , " expected_result") ,
326353 [
327354 ("1.0.0" , True ), # Test 1: v1 identified as v1
328355 ("2.3.4" , False ), # Test 2: v2 not identified as v1
329356 ],
330357)
331- def test_is_sqlalchemy_v1 (mock_version , expected_result ):
332- with mock .patch ("airflow.utils.sqlalchemy.metadata" ) as mock_metadata :
333- mock_metadata .version .return_value = mock_version
334- assert is_sqlalchemy_v1 () == expected_result
358+ def test_is_sqlalchemy_v1 (mock_version , expected_result , mocker ):
359+ mock_metadata = mocker .patch ("airflow.utils.sqlalchemy.metadata" )
360+ mock_metadata .version .return_value = mock_version
361+ assert is_sqlalchemy_v1 () == expected_result
0 commit comments