Skip to content

Commit f6273e7

Browse files
authored
feat: unify database class (#1068)
1 parent 36747ba commit f6273e7

File tree

96 files changed

+868
-1169
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

96 files changed

+868
-1169
lines changed

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from feeds.impl.error_handling import raise_http_error, raise_http_validation_error, convert_exception
4646
from middleware.request_context import is_user_email_restricted
4747
from utils.date_utils import valid_iso_date
48-
from utils.logger import Logger
48+
from shared.common.logging_utils import Logger
4949

5050
T = TypeVar("T", bound="BasicFeed")
5151

@@ -66,7 +66,7 @@ def __init__(self) -> None:
6666
def get_feed(self, id: str, db_session: Session) -> BasicFeed:
6767
"""Get the specified feed from the Mobility Database."""
6868
is_email_restricted = is_user_email_restricted()
69-
self.logger.info(f"User email is restricted: {is_email_restricted}")
69+
self.logger.debug(f"User email is restricted: {is_email_restricted}")
7070

7171
feed = (
7272
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
@@ -98,7 +98,7 @@ def get_feeds(
9898
) -> List[BasicFeed]:
9999
"""Get some (or all) feeds from the Mobility Database."""
100100
is_email_restricted = is_user_email_restricted()
101-
self.logger.info(f"User email is restricted: {is_email_restricted}")
101+
self.logger.debug(f"User email is restricted: {is_email_restricted}")
102102
feed_filter = FeedFilter(
103103
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
104104
)
@@ -137,7 +137,6 @@ def _get_gtfs_feed(
137137
query = get_gtfs_feeds_query(
138138
db_session=db_session, stable_id=stable_id, include_options_for_joinedload=include_options_for_joinedload
139139
)
140-
self.logger.debug("Query: %s", str(query.statement.compile(compile_kwargs={"literal_binds": True})))
141140
results = query.all()
142141
if len(results) == 0:
143142
return None

api/src/feeds/impl/models/validation_report_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from shared.database_gen.sqlacodegen_models import Validationreport
22
from feeds_gen.models.validation_report import ValidationReport
3-
from utils.logger import Logger
3+
from shared.common.logging_utils import Logger
44

55
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S"
66

api/src/scripts/load_dataset_on_create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from google.cloud.pubsub_v1.futures import Future
1111

1212
from shared.database_gen.sqlacodegen_models import Feed
13-
from utils.logger import Logger
13+
from shared.common.logging_utils import Logger
1414

1515
# Lazy create so we won't try to connect to google cloud when the file is imported.
1616
pubsub_client = None

api/src/scripts/populate_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from shared.database.database import Database
1111
from shared.database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
12-
from utils.logger import Logger
12+
from shared.common.logging_utils import Logger
1313

1414
if TYPE_CHECKING:
1515
from sqlalchemy.orm import Session

api/src/scripts/populate_db_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Officialstatushistory,
1818
)
1919
from scripts.populate_db import set_up_configs, DatabasePopulateHelper
20-
from utils.logger import Logger
20+
from shared.common.logging_utils import Logger
2121
from typing import TYPE_CHECKING
2222

2323
if TYPE_CHECKING:
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import logging
2+
import os
3+
4+
5+
def get_env_logging_level():
    """
    Get the logging level from the environment via OS variable LOGGING_LEVEL.

    :return: the configured level name, or "INFO" if the variable is not set.
    """
    return os.getenv("LOGGING_LEVEL", "INFO")


class Logger:
    """
    Util class for logging information, errors or warnings.

    Wraps a `logging.Logger` configured with a console handler and the level
    taken from the LOGGING_LEVEL environment variable.
    """

    def __init__(self, name):
        """
        Initialize the logger.

        :param name: name of the underlying `logging.Logger`; `logging.getLogger`
            returns the same object for the same name.
        """
        self.logger = logging.getLogger(name)
        # `logging.getLogger(name)` is shared per name, so attaching a handler
        # unconditionally would add a duplicate console handler every time
        # Logger(name) is constructed, emitting each record multiple times.
        # Only attach the handler the first time this named logger is set up.
        if not self.logger.handlers:
            formatter = logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)
        # The level is refreshed on every construction so a changed
        # LOGGING_LEVEL takes effect for existing named loggers too.
        self.logger.setLevel(get_env_logging_level())

    def get_logger(self):
        """
        Get the logger instance
        :return: the logger instance
        """
        return self.logger

api/src/shared/database/database.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import uuid
66
from typing import Type, Callable
77
from dotenv import load_dotenv
8-
from sqlalchemy import create_engine
9-
from sqlalchemy.orm import load_only, Query, class_mapper, Session
8+
from sqlalchemy import create_engine, text, event
9+
from sqlalchemy.orm import load_only, Query, class_mapper, Session, mapper
1010
from shared.database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed
1111
from sqlalchemy.orm import sessionmaker
1212
import logging
1313

14+
from shared.common.logging_utils import get_env_logging_level
15+
1416

1517
def generate_unique_id() -> str:
1618
"""
@@ -42,7 +44,48 @@ def configure_polymorphic_mappers():
4244
gbfsfeed_mapper.polymorphic_identity = Gbfsfeed.__tablename__.lower()
4345

4446

45-
def with_db_session(func):
47+
def set_cascade(mapper, class_):
    """
    Set cascade for relationships in Gtfsfeed.
    This allows to delete/add the relationships when their respective relation array changes.

    :param mapper: the Mapper instance being configured (unused, kept for the
        event-listener signature)
    :param class_: the mapped class; only Gtfsfeed is modified
    """
    # Only the Gtfsfeed mapping gets the cascade tweak; everything else is
    # left untouched.
    if class_.__name__ != "Gtfsfeed":
        return
    cascading_keys = {
        "redirectingids",
        "redirectingids_",
        "externalids",
        "externalids_",
    }
    for relationship in class_.__mapper__.relationships:
        if relationship.key in cascading_keys:
            relationship.cascade = "all, delete-orphan"
61+
62+
63+
def mapper_configure_listener(mapper, class_):
    """
    Mapper configure listener.

    Registered below on SQLAlchemy's "mapper_configured" event; fires once per
    mapped class. Applies the Gtfsfeed cascade tweaks (`set_cascade`) and then
    re-applies the polymorphic identities (`configure_polymorphic_mappers`).

    :param mapper: the Mapper instance that was just configured
    :param class_: the mapped class the event fired for
    """
    set_cascade(mapper, class_)
    configure_polymorphic_mappers()
69+
70+
71+
# Add the mapper_configure_listener to the mapper_configured event
72+
event.listen(mapper, "mapper_configured", mapper_configure_listener)
73+
74+
75+
def refresh_materialized_view(session: "Session", view_name: str) -> bool:
    """
    Refresh Materialized view by name.

    :param session: active SQLAlchemy session used to execute the refresh
    :param view_name: name of the materialized view; must come from trusted
        code — it is interpolated directly into the SQL statement
    :return: True if the view was refreshed successfully, False otherwise
    """
    try:
        # CONCURRENTLY refreshes without locking out readers of the view
        # (requires PostgreSQL and a unique index on the view).
        session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}"))
        return True
    except Exception as error:
        # logging.exception keeps the traceback, which would otherwise be
        # lost since the failure is converted into a boolean for the caller.
        # Lazy %-formatting defers message construction to the logging layer.
        logging.exception("Error raised while refreshing view: %s", error)
        return False
86+
87+
88+
def with_db_session(func=None, db_url: str | None = None):
4689
"""
4790
Decorator to handle the session management for the decorated function.
4891
@@ -58,12 +101,15 @@ def with_db_session(func):
58101
exception occurs, and closed in either case.
59102
- The session is then passed to the decorated function as the 'db_session' keyword argument.
60103
- If 'db_session' is already provided, it simply calls the decorated function with the existing session.
104+
- The echoed SQL queries will be logged if the environment variable LOGGING_LEVEL is set to DEBUG.
61105
"""
106+
if func is None:
107+
return lambda f: with_db_session(f, db_url=db_url)
62108

63109
def wrapper(*args, **kwargs):
64110
db_session = kwargs.get("db_session")
65111
if db_session is None:
66-
db = Database()
112+
db = Database(echo_sql=get_env_logging_level() == "DEBUG", feeds_database_url=db_url)
67113
with db.start_db_session() as session:
68114
kwargs["db_session"] = session
69115
return func(*args, **kwargs)
@@ -89,12 +135,16 @@ def __new__(cls, *args, **kwargs):
89135
cls.instance = object.__new__(cls)
90136
return cls.instance
91137

92-
def __init__(self, echo_sql=False):
138+
def __init__(self, echo_sql=False, feeds_database_url: str | None = None):
93139
"""
94140
Initializes the database instance
95-
:param echo_sql: whether to echo the SQL queries or not
96-
echo_sql set to False reduces the amount of information and noise going to the logs.
97-
In case of errors, the exceptions will still contain relevant information about the failing queries.
141+
142+
:param echo_sql: whether to echo the SQL queries or not echo_sql.
143+
False reduces the amount of information and noise going to the logs.
144+
In case of errors, the exceptions will still contain relevant information about the failing queries.
145+
146+
:param feeds_database_url: The URL of the target database.
147+
If it's None the URL will be assigned from the environment variable FEEDS_DATABASE_URL.
98148
"""
99149

100150
# This init function is called each time we call Database(), but in the case of a singleton, we only want to
@@ -107,7 +157,7 @@ def __init__(self, echo_sql=False):
107157
load_dotenv()
108158
self.logger = logging.getLogger(__name__)
109159
self.connection_attempts = 0
110-
database_url = os.getenv("FEEDS_DATABASE_URL")
160+
database_url = feeds_database_url if feeds_database_url else os.getenv("FEEDS_DATABASE_URL")
111161
if database_url is None:
112162
raise Exception("Database URL not provided.")
113163
self.pool_size = int(os.getenv("DB_POOL_SIZE", 10))

api/src/utils/logger.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -125,29 +125,3 @@ async def async_emit(self, record):
125125
jsonPayload=json_payload,
126126
)
127127
self.logger.info(json.dumps(log_record.__dict__))
128-
129-
130-
class Logger:
131-
"""
132-
Util class for logging information, errors or warnings
133-
"""
134-
135-
def __init__(self, name):
136-
"""
137-
Initialize the logger
138-
"""
139-
formatter = logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
140-
141-
console_handler = logging.StreamHandler()
142-
console_handler.setFormatter(formatter)
143-
144-
self.logger = logging.getLogger(name)
145-
self.logger.addHandler(console_handler)
146-
self.logger.setLevel(logging.DEBUG)
147-
148-
def get_logger(self):
149-
"""
150-
Get the logger instance
151-
:return: the logger instance
152-
"""
153-
return self.logger

api/tests/utils/test_logger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import unittest
22
from unittest.mock import patch
33

4-
from utils.logger import HttpRequest, LogRecord, AsyncStreamHandler, GCPLogHandler, Logger
4+
from utils.logger import HttpRequest, LogRecord, AsyncStreamHandler, GCPLogHandler
5+
from shared.common.logging_utils import Logger
56

67

78
class TestLogger(unittest.TestCase):

functions-python/backfill_dataset_service_date_range/.coveragerc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
[run]
22
omit =
33
*/test*/*
4-
*/database_gen/*
5-
*/dataset_service/*
64
*/helpers/*
5+
*/shared/*
76

87
[report]
98
exclude_lines =

0 commit comments

Comments
 (0)