Skip to content

Commit d8a498a

Browse files
committed
more refactoring
1 parent 4745fef commit d8a498a

File tree

5 files changed

+48
-45
lines changed

5 files changed

+48
-45
lines changed

functions-python/batch_datasets/tests/test_batch_datasets_main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,14 @@ def test_batch_datasets(mock_client, mock_publish):
6464
]
6565

6666

67-
@patch("batch_datasets.src.main.start_db_session")
68-
def test_batch_datasets_exception(start_db_session_mock):
67+
@patch("batch_datasets.src.main.Database")
68+
def test_batch_datasets_exception(database_mock):
6969
exception_message = "Failure occurred"
70-
start_db_session_mock.side_effect = Exception(exception_message)
70+
mock_session = MagicMock()
71+
mock_session.side_effect = Exception(exception_message)
72+
73+
database_mock.return_value.start_db_session.return_value = mock_session
74+
7175
with pytest.raises(Exception) as exec_info:
7276
batch_datasets(Mock())
7377

functions-python/extract_location/tests/test_location_extraction.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,12 @@ def test_extract_location_exception_2(
268268
"GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
269269
},
270270
)
271-
@patch("extract_location.src.main.start_db_session")
271+
@patch("extract_location.src.main.Database")
272272
@patch("extract_location.src.main.pubsub_v1.PublisherClient")
273273
@patch("extract_location.src.main.Logger")
274274
@patch("uuid.uuid4")
275275
def test_extract_location_batch(
276-
self, uuid_mock, logger_mock, publisher_client_mock, start_db_session_mock
276+
self, uuid_mock, logger_mock, publisher_client_mock, database_mock
277277
):
278278
mock_session = MagicMock()
279279
mock_dataset1 = Gtfsdataset(
@@ -300,7 +300,7 @@ def test_extract_location_batch(
300300
mock_dataset2,
301301
]
302302
uuid_mock.return_value = "batch-uuid"
303-
start_db_session_mock.return_value = mock_session
303+
database_mock.return_value.start_db_session.return_value = mock_session
304304

305305
mock_publisher = MagicMock()
306306
publisher_client_mock.return_value = mock_publisher
@@ -358,10 +358,13 @@ def test_extract_location_batch_no_topic_name(self, logger_mock):
358358
"GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
359359
},
360360
)
361-
@patch("extract_location.src.main.start_db_session")
361+
@patch("extract_location.src.main.Database")
362362
@patch("extract_location.src.main.Logger")
363-
def test_extract_location_batch_exception(self, logger_mock, start_db_session_mock):
364-
start_db_session_mock.side_effect = Exception("Database error")
363+
def test_extract_location_batch_exception(self, logger_mock, database_mock):
364+
mock_session = MagicMock()
365+
mock_session.side_effect = Exception("Database error")
366+
367+
database_mock.return_value.start_db_session.return_value = mock_session
365368

366369
response = extract_location_batch(None)
367370
self.assertEqual(response, ("Error while fetching datasets.", 500))

functions-python/gbfs_validator/tests/test_gbfs_validator.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class TestMainFunctions(unittest.TestCase):
2828
"VALIDATOR_URL": "https://mock-validator-url.com",
2929
},
3030
)
31-
@patch("gbfs_validator.src.main.start_db_session")
31+
@patch("gbfs_validator.src.main.Database")
3232
@patch("gbfs_validator.src.main.DatasetTraceService")
3333
@patch("gbfs_validator.src.main.fetch_gbfs_files")
3434
@patch("gbfs_validator.src.main.GBFSValidator.create_gbfs_json_with_bucket_paths")
@@ -47,11 +47,11 @@ def test_gbfs_validator_pubsub(
4747
mock_create_gbfs_json,
4848
mock_fetch_gbfs_files,
4949
mock_dataset_trace_service,
50-
mock_start_db_session,
50+
mock_database,
5151
):
5252
# Prepare mocks
5353
mock_session = MagicMock()
54-
mock_start_db_session.return_value = mock_session
54+
mock_database.return_value.start_db_session.return_value = mock_session
5555

5656
mock_trace_service = MagicMock()
5757
mock_dataset_trace_service.return_value = mock_trace_service
@@ -95,16 +95,16 @@ def test_gbfs_validator_pubsub(
9595
"PUBSUB_TOPIC_NAME": "mock-topic",
9696
},
9797
)
98-
@patch("gbfs_validator.src.main.start_db_session")
98+
@patch("gbfs_validator.src.main.Database")
9999
@patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
100100
@patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
101101
@patch("gbfs_validator.src.main.Logger")
102102
def test_gbfs_validator_batch(
103-
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
103+
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
104104
):
105105
# Prepare mocks
106106
mock_session = MagicMock()
107-
mock_start_db_session.return_value = mock_session
107+
mock_database.return_value.start_db_session.return_value = mock_session
108108

109109
mock_publisher = MagicMock()
110110
mock_publisher_client.return_value = mock_publisher
@@ -131,11 +131,11 @@ def test_gbfs_validator_batch_missing_topic(self, _): # mock_logger
131131
result = gbfs_validator_batch(None)
132132
self.assertEqual(result[1], 500)
133133

134-
@patch("gbfs_validator.src.main.start_db_session")
134+
@patch("gbfs_validator.src.main.Database")
135135
@patch("gbfs_validator.src.main.Logger")
136-
def test_fetch_all_gbfs_feeds(self, _, mock_start_db_session):
136+
def test_fetch_all_gbfs_feeds(self, _, mock_database):
137137
mock_session = MagicMock()
138-
mock_start_db_session.return_value = mock_session
138+
mock_database.return_value.start_db_session.return_value = mock_session
139139
mock_feed = MagicMock()
140140
mock_session.query.return_value.options.return_value.all.return_value = [
141141
mock_feed
@@ -144,14 +144,14 @@ def test_fetch_all_gbfs_feeds(self, _, mock_start_db_session):
144144
result = fetch_all_gbfs_feeds()
145145
self.assertEqual(result, [mock_feed])
146146

147-
mock_start_db_session.assert_called_once()
147+
mock_database.assert_called_once()
148148
mock_session.close.assert_called_once()
149149

150-
@patch("gbfs_validator.src.main.start_db_session")
150+
@patch("gbfs_validator.src.main.Database")
151151
@patch("gbfs_validator.src.main.Logger")
152-
def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):
152+
def test_fetch_all_gbfs_feeds_exception(self, _, mock_database):
153153
mock_session = MagicMock()
154-
mock_start_db_session.return_value = mock_session
154+
mock_database.return_value.start_db_session.return_value = mock_session
155155

156156
# Simulate an exception when querying the database
157157
mock_session.query.side_effect = Exception("Database error")
@@ -161,19 +161,19 @@ def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):
161161

162162
self.assertTrue("Database error" in str(context.exception))
163163

164-
mock_start_db_session.assert_called_once()
164+
mock_database.assert_called_once()
165165
mock_session.close.assert_called_once()
166166

167-
@patch("gbfs_validator.src.main.start_db_session")
168-
def test_fetch_all_gbfs_feeds_none_session(self, mock_start_db_session):
169-
mock_start_db_session.return_value = None
167+
@patch("gbfs_validator.src.main.Database")
168+
def test_fetch_all_gbfs_feeds_none_session(self, mock_database):
169+
mock_database.return_value = None
170170

171171
with self.assertRaises(Exception) as context:
172172
fetch_all_gbfs_feeds()
173173

174174
self.assertTrue("NoneType" in str(context.exception))
175175

176-
mock_start_db_session.assert_called_once()
176+
mock_database.assert_called_once()
177177

178178
@patch.dict(
179179
os.environ,
@@ -199,16 +199,16 @@ def test_gbfs_validator_batch_fetch_exception(self, _, mock_fetch_all_gbfs_feeds
199199
"PUBSUB_TOPIC_NAME": "mock-topic",
200200
},
201201
)
202-
@patch("gbfs_validator.src.main.start_db_session")
202+
@patch("gbfs_validator.src.main.Database")
203203
@patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
204204
@patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
205205
@patch("gbfs_validator.src.main.Logger")
206206
def test_gbfs_validator_batch_publish_exception(
207-
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
207+
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
208208
):
209209
# Prepare mocks
210210
mock_session = MagicMock()
211-
mock_start_db_session.return_value = mock_session
211+
mock_database.return_value.start_db_session.return_value = mock_session
212212

213213
mock_publisher_client.side_effect = Exception("Pub/Sub error")
214214

functions-python/validation_to_ndjson/src/utils/locations.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlalchemy.orm import joinedload
55

66
from database_gen.sqlacodegen_models import Feed, Location
7-
from helpers.database import start_db_session
7+
from helpers.database import Database
88

99

1010
def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
@@ -14,9 +14,8 @@ def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
1414
@param stable_id: The stable ID of the feed.
1515
@return: A list of locations.
1616
"""
17-
session = None
18-
try:
19-
session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
17+
db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
18+
with db.start_db_session() as session:
2019
feeds = (
2120
session.query(Feed)
2221
.filter(Feed.data_type == data_type)
@@ -25,6 +24,3 @@ def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
2524
.all()
2625
)
2726
return feeds[0].locations if feeds is not None and len(feeds) > 0 else []
28-
finally:
29-
if session:
30-
session.close()

functions-python/validation_to_ndjson/tests/test_locations.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66

77
class TestFeedsLocations(unittest.TestCase):
8-
@patch("validation_to_ndjson.src.utils.locations.start_db_session")
8+
@patch("validation_to_ndjson.src.utils.locations.Database")
99
@patch("validation_to_ndjson.src.utils.locations.os.getenv")
1010
@patch("validation_to_ndjson.src.utils.locations.joinedload")
11-
def test_get_feeds_locations_map(self, _, mock_getenv, mock_start_db_session):
11+
def test_get_feeds_locations_map(self, _, mock_getenv, mock_database):
1212
mock_getenv.return_value = "fake_database_url"
1313

1414
mock_session = MagicMock()
15-
mock_start_db_session.return_value = mock_session
15+
mock_database.return_value.start_db_session.return_value = mock_session
1616

1717
mock_feed = MagicMock()
1818
mock_feed.stable_id = "feed1"
@@ -28,18 +28,18 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_start_db_session):
2828
mock_session.query.return_value = mock_query
2929
result = get_feed_location("gtfs", "feed1")
3030

31-
mock_start_db_session.assert_called_once_with("fake_database_url")
31+
mock_database.assert_called_once_with("fake_database_url")
3232
mock_session.query.assert_called_once() # Verify that query was called
3333
mock_query.filter.assert_called_once() # Verify that filter was applied
3434
mock_query.filter.return_value.filter.return_value.options.assert_called_once()
3535
mock_query.filter.return_value.filter.return_value.options.return_value.all.assert_called_once()
3636

3737
self.assertEqual(result, [mock_location1, mock_location2]) # Verify the mapping
3838

39-
@patch("validation_to_ndjson.src.utils.locations.start_db_session")
40-
def test_get_feeds_locations_map_no_feeds(self, mock_start_db_session):
39+
@patch("validation_to_ndjson.src.utils.locations.Database")
40+
def test_get_feeds_locations_map_no_feeds(self, mock_database):
4141
mock_session = MagicMock()
42-
mock_start_db_session.return_value = mock_session
42+
mock_database.return_value.start_db_session.return_value = mock_session
4343

4444
mock_query = MagicMock()
4545
mock_query.filter.return_value.filter.return_value.options.return_value.all.return_value = (
@@ -50,5 +50,5 @@ def test_get_feeds_locations_map_no_feeds(self, mock_start_db_session):
5050

5151
result = get_feed_location("test_data_type", "test_stable_id")
5252

53-
mock_start_db_session.assert_called_once()
53+
mock_database.return_value.start_db_session.assert_called_once()
5454
self.assertEqual(result, []) # The result should be an empty dictionary

0 commit comments

Comments
 (0)