22
33import  asyncio 
44import  random 
5+ from  dataclasses  import  dataclass 
56from  typing  import  TYPE_CHECKING , Any , Callable , Coroutine , Optional , TypeVar , Union 
67
78from  neo4j  import  (
3031from  infrahub .log  import  get_logger 
3132from  infrahub .utils  import  InfrahubStringEnum 
3233
33- from  .constants  import  DatabaseType 
34+ from  .constants  import  DatabaseType ,  Neo4jRuntime 
3435from  .memgraph  import  DatabaseManagerMemgraph 
3536from  .metrics  import  QUERY_EXECUTION_METRICS , TRANSACTION_RETRIES 
3637from  .neo4j  import  DatabaseManagerNeo4j 
5051log  =  get_logger ()
5152
5253
54+ @dataclass  
55+ class  QueryConfig :
56+     neo4j_runtime : Neo4jRuntime  =  Neo4jRuntime .DEFAULT 
57+     profile_memory : bool  =  False 
58+ 
59+ 
5360class  InfrahubDatabaseMode (InfrahubStringEnum ):
5461    DRIVER  =  "driver" 
5562    SESSION  =  "session" 
@@ -134,13 +141,15 @@ def __init__(
134141        session : Optional [AsyncSession ] =  None ,
135142        session_mode : InfrahubDatabaseSessionMode  =  InfrahubDatabaseSessionMode .WRITE ,
136143        transaction : Optional [AsyncTransaction ] =  None ,
137-     ) ->  None :
144+         queries_names_to_config : Optional [dict [str , QueryConfig ]] =  None ,
145+     ):
138146        self ._mode : InfrahubDatabaseMode  =  mode 
139147        self ._driver : AsyncDriver  =  driver 
140148        self ._session : Optional [AsyncSession ] =  session 
141149        self ._session_mode : InfrahubDatabaseSessionMode  =  session_mode 
142150        self ._is_session_local : bool  =  False 
143151        self ._transaction : Optional [AsyncTransaction ] =  transaction 
152+         self .queries_names_to_config  =  queries_names_to_config  if  queries_names_to_config  is  not None  else  {}
144153
145154        if  schemas :
146155            self ._schemas : dict [str , SchemaBranch ] =  {schema .name : schema  for  schema  in  schemas }
@@ -189,6 +198,7 @@ def start_session(self, read_only: bool = False, schemas: Optional[list[SchemaBr
189198            db_manager = self .manager ,
190199            driver = self ._driver ,
191200            session_mode = session_mode ,
201+             queries_names_to_config = self .queries_names_to_config ,
192202        )
193203
194204    def  start_transaction (self , schemas : Optional [list [SchemaBranch ]] =  None ) ->  InfrahubDatabase :
@@ -200,6 +210,7 @@ def start_transaction(self, schemas: Optional[list[SchemaBranch]] = None) -> Inf
200210            driver = self ._driver ,
201211            session = self ._session ,
202212            session_mode = self ._session_mode ,
213+             queries_names_to_config = self .queries_names_to_config ,
203214        )
204215
205216    async  def  session (self ) ->  AsyncSession :
@@ -274,14 +285,8 @@ async def close(self) -> None:
274285    async  def  execute_query (
275286        self , query : str , params : Optional [dict [str , Any ]] =  None , name : Optional [str ] =  "undefined" 
276287    ) ->  list [Record ]:
277-         with  trace .get_tracer (__name__ ).start_as_current_span ("execute_db_query" ) as  span :
278-             span .set_attribute ("query" , query )
279-             if  name :
280-                 span .set_attribute ("query_name" , name )
281- 
282-             with  QUERY_EXECUTION_METRICS .labels (self ._session_mode .value , name ).time ():
283-                 response  =  await  self .run_query (query = query , params = params )
284-                 return  [item  async  for  item  in  response ]
288+         results , _  =  await  self .execute_query_with_metadata (query = query , params = params , name = name )
289+         return  results 
285290
286291    async  def  execute_query_with_metadata (
287292        self , query : str , params : Optional [dict [str , Any ]] =  None , name : Optional [str ] =  "undefined" 
@@ -291,6 +296,17 @@ async def execute_query_with_metadata(
291296            if  name :
292297                span .set_attribute ("query_name" , name )
293298
299+             try :
300+                 query_config  =  self .queries_names_to_config [name ]
301+                 if  self .db_type  ==  DatabaseType .NEO4J :
302+                     runtime  =  self .queries_names_to_config [name ].neo4j_runtime 
303+                     if  runtime  !=  Neo4jRuntime .DEFAULT :
304+                         query  =  f"CYPHER runtime = { runtime .value } \n "  +  query 
305+                 if  query_config .profile_memory :
306+                     query  =  "PROFILE\n "  +  query 
307+             except  KeyError :
308+                 pass   # No specific config for this query 
309+ 
294310            with  QUERY_EXECUTION_METRICS .labels (self ._session_mode .value , name ).time ():
295311                response  =  await  self .run_query (query = query , params = params , name = name )
296312                results  =  [item  async  for  item  in  response ]
0 commit comments