55import uuid
66from typing import Type , Callable
77from dotenv import load_dotenv
8- from sqlalchemy import create_engine
9- from sqlalchemy .orm import load_only , Query , class_mapper , Session
8+ from sqlalchemy import create_engine , text , event
9+ from sqlalchemy .orm import load_only , Query , class_mapper , Session , mapper
1010from shared .database_gen .sqlacodegen_models import Base , Feed , Gtfsfeed , Gtfsrealtimefeed , Gbfsfeed
1111from sqlalchemy .orm import sessionmaker
1212import logging
1313
14+ from shared .common .logging_utils import get_env_logging_level
15+
1416
1517def generate_unique_id () -> str :
1618 """
@@ -42,7 +44,48 @@ def configure_polymorphic_mappers():
4244 gbfsfeed_mapper .polymorphic_identity = Gbfsfeed .__tablename__ .lower ()
4345
4446
45- def with_db_session (func ):
47+ def set_cascade (mapper , class_ ):
48+ """
49+ Set cascade for relationships in Gtfsfeed.
50+ This allows to delete/add the relationships when their respective relation array changes.
51+ """
52+ if class_ .__name__ == "Gtfsfeed" :
53+ for rel in class_ .__mapper__ .relationships :
54+ if rel .key in [
55+ "redirectingids" ,
56+ "redirectingids_" ,
57+ "externalids" ,
58+ "externalids_" ,
59+ ]:
60+ rel .cascade = "all, delete-orphan"
61+
62+
63+ def mapper_configure_listener (mapper , class_ ):
64+ """
65+ Mapper configure listener
66+ """
67+ set_cascade (mapper , class_ )
68+ configure_polymorphic_mappers ()
69+
70+
71+ # Add the mapper_configure_listener to the mapper_configured event
72+ event .listen (mapper , "mapper_configured" , mapper_configure_listener )
73+
74+
75+ def refresh_materialized_view (session : "Session" , view_name : str ) -> bool :
76+ """
77+ Refresh Materialized view by name.
78+ @return: True if the view was refreshed successfully, False otherwise
79+ """
80+ try :
81+ session .execute (text (f"REFRESH MATERIALIZED VIEW CONCURRENTLY { view_name } " ))
82+ return True
83+ except Exception as error :
84+ logging .error (f"Error raised while refreshing view: { error } " )
85+ return False
86+
87+
88+ def with_db_session (func = None , db_url : str | None = None ):
4689 """
4790 Decorator to handle the session management for the decorated function.
4891
@@ -58,12 +101,15 @@ def with_db_session(func):
58101 exception occurs, and closed in either case.
59102 - The session is then passed to the decorated function as the 'db_session' keyword argument.
60103 - If 'db_session' is already provided, it simply calls the decorated function with the existing session.
104+ - The echoed SQL queries will be logged if the environment variable LOGGING_LEVEL is set to DEBUG.
61105 """
106+ if func is None :
107+ return lambda f : with_db_session (f , db_url = db_url )
62108
63109 def wrapper (* args , ** kwargs ):
64110 db_session = kwargs .get ("db_session" )
65111 if db_session is None :
66- db = Database ()
112+ db = Database (echo_sql = get_env_logging_level () == "DEBUG" , feeds_database_url = db_url )
67113 with db .start_db_session () as session :
68114 kwargs ["db_session" ] = session
69115 return func (* args , ** kwargs )
@@ -89,12 +135,16 @@ def __new__(cls, *args, **kwargs):
89135 cls .instance = object .__new__ (cls )
90136 return cls .instance
91137
92- def __init__ (self , echo_sql = False ):
138+ def __init__ (self , echo_sql = False , feeds_database_url : str | None = None ):
93139 """
94140 Initializes the database instance
95- :param echo_sql: whether to echo the SQL queries or not
96- echo_sql set to False reduces the amount of information and noise going to the logs.
97- In case of errors, the exceptions will still contain relevant information about the failing queries.
141+
142+ :param echo_sql: whether to echo the SQL queries or not echo_sql.
143+ False reduces the amount of information and noise going to the logs.
144+ In case of errors, the exceptions will still contain relevant information about the failing queries.
145+
146+ :param feeds_database_url: The URL of the target database.
147+ If it's None the URL will be assigned from the environment variable FEEDS_DATABASE_URL.
98148 """
99149
100150 # This init function is called each time we call Database(), but in the case of a singleton, we only want to
@@ -107,7 +157,7 @@ def __init__(self, echo_sql=False):
107157 load_dotenv ()
108158 self .logger = logging .getLogger (__name__ )
109159 self .connection_attempts = 0
110- database_url = os .getenv ("FEEDS_DATABASE_URL" )
160+ database_url = feeds_database_url if feeds_database_url else os .getenv ("FEEDS_DATABASE_URL" )
111161 if database_url is None :
112162 raise Exception ("Database URL not provided." )
113163 self .pool_size = int (os .getenv ("DB_POOL_SIZE" , 10 ))
0 commit comments