22from dataclasses import dataclass
33from pathlib import Path
44from types import TracebackType
5- from typing import Any , Optional , Type
5+ from typing import Any , List , Optional , Self , Type
66
77import matplotlib .pyplot as plt
88import pandas as pd
@@ -38,53 +38,27 @@ class QueryMeasurement:
3838 memory : Optional [float ] = None
3939
4040
41- class QueryAnalyzer :
42- name : Optional [str ]
43- measurements : list [QueryMeasurement ]
44- output_location : Path
45- nb_elements_loaded : int
46- profile_memory : bool
47- profile_duration : bool
48-
49- def __init__ (self ) -> None :
50- self .reset ()
51-
52- def reset (self ) -> None :
53- self .name = None
54- self .measurements = []
55- self .output_location = Path .cwd ()
56- self .nb_elements_loaded = 0
57- self .profile_duration = False
58- self .profile_memory = False
59-
60- def increase_nb_elements_loaded (self , increment : int ) -> None :
61- self .nb_elements_loaded += increment
62-
63- def get_df (self ) -> pd .DataFrame :
41+ class GraphProfileGenerator :
def build_df_from_measuremenst(self, measurements: list[QueryMeasurement]) -> pd.DataFrame:
    """Convert a list of measurements into a DataFrame.

    The frame has one column per ``QueryMeasurement`` dataclass field and
    one row per measurement, in the order the measurements were given.

    NOTE(review): the method name is misspelled ("measuremenst"); it is
    kept as-is because callers reference it under this name.
    """
    columns = {
        field_name: [getattr(measurement, field_name) for measurement in measurements]
        for field_name in QueryMeasurement.__dataclass_fields__.keys()
    }
    return pd.DataFrame(columns)
6948
70- def add_measurement (self , measurement : QueryMeasurement ) -> None :
71- measurement .nb_elements_loaded = self .nb_elements_loaded
72- self .measurements .append (measurement )
73-
74- def create_graphs (self , output_location : Path , label : str ) -> None :
75- df = self .get_df ()
def create_graphs(self, measurements: List[QueryMeasurement], output_location: Path, label: str) -> None:
    """Create one duration graph per distinct query name found in ``measurements``.

    Args:
        measurements: Measurements to plot; they are converted to a DataFrame first.
        output_location: Directory receiving the graphs. It is created,
            including missing parents, when it does not exist yet.
        label: Forwarded unchanged to ``create_duration_graph``.
    """
    df = self.build_df_from_measuremenst(measurements)

    # Deduplicate while keeping first-seen order, so graphs are produced in a
    # reproducible order (iteration order of a set of strings varies per run).
    query_names = list(dict.fromkeys(df["query_name"].tolist()))

    # exist_ok=True avoids the check-then-create race of a separate exists() test.
    output_location.mkdir(parents=True, exist_ok=True)

    for query_name in query_names:
        self.create_duration_graph(df=df, query_name=query_name, label=label, output_dir=output_location)
    # NOTE: memory graphs are currently not generated; only duration graphs are.
8459
85- def create_duration_graph (self , query_name : str , label : str , output_dir : Path ) -> None :
60+ def create_duration_graph (self , df : pd . DataFrame , query_name : str , label : str , output_dir : Path ) -> None :
8661 metric = "duration"
87- df = self .get_df ()
8862
8963 name = f"{ query_name } _{ metric } "
9064 plt .figure (name )
@@ -105,71 +79,45 @@ def create_duration_graph(self, query_name: str, label: str, output_dir: Path) -
10579 file_name = f"{ name } .png"
10680 plt .savefig (str (output_dir / file_name ), bbox_inches = "tight" )
10781
108- def create_memory_graph (self , query_name : str , label : str , output_dir : Path ) -> None :
109- metric = "memory"
110- df = self .get_df ()
111- df_query = df [(df ["query_name" ] == query_name ) & (~ df ["memory" ].isna ())]
112-
113- plt .figure (query_name )
114-
115- x = df_query ["nb_elements_loaded" ].values
116- y = df_query [metric ].values
117-
118- plt .plot (x , y , label = label )
119-
120- plt .legend ()
121-
122- plt .ylabel ("memory" , fontsize = 15 )
123- plt .title (f"Query - { query_name } | { metric } " , fontsize = 20 )
124-
125- file_name = f"{ query_name } _{ metric } .png"
12682
127- plt .savefig (str (output_dir / file_name ))
128-
129-
130- class ProfilerEnabler :
83+ class InfrahubDatabaseProfiler (InfrahubDatabase ):
84+ profiling_enabled : bool
13185 profile_memory : bool
86+ measurements : List [QueryMeasurement ]
87+ nb_elements_loaded : int
13288
133- def __init__ (self , profile_memory : bool , query_analyzer : QueryAnalyzer ) -> None :
134- self .profile_memory = profile_memory
135- self .query_analyzer = query_analyzer
136-
137- def __enter__ (self ) -> None :
138- self .query_analyzer .profile_duration = True
139- self .query_analyzer .profile_memory = self .profile_memory
140-
141- def __exit__ (
142- self , exc_type : Optional [Type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
143- ) -> None :
144- self .query_analyzer .profile_duration = False
145- self .query_analyzer .profile_memory = False
146-
147-
148- # Tricky to have it as an attribute of InfrahubDatabaseProfiler as some copies of InfrahubDatabase are made
149- # during start_session calls.
150- # query_analyzer = QueryAnalyzer()
151-
152-
153- class InfrahubDatabaseProfiler (InfrahubDatabase ):
154- def __init__ (self , ** kwargs : Any ) -> None :
def __init__(
    self,
    profiling_enabled: bool = False,
    profile_memory: bool = False,
    measurements: Optional[List[QueryMeasurement]] = None,
    nb_elements_loaded: int = 0,
    **kwargs: Any,
) -> None:
    """Initialise the profiling state on top of the parent database object.

    Args:
        profiling_enabled: Whether executed queries are measured.
        profile_memory: Whether memory profiling is requested.
        measurements: Existing list to append measurements to; a new empty
            list is created when omitted.
        nb_elements_loaded: Starting value of the loaded-elements counter.
        **kwargs: Forwarded unchanged to the parent constructor.

    TODO: these arguments live in the constructor only because of the
    ``__class__`` pattern.
    """
    super().__init__(**kwargs)

    # Any attribute added below must also be added to get_context().
    self.profiling_enabled = profiling_enabled
    self.profile_memory = profile_memory
    self.measurements = [] if measurements is None else measurements
    self.nb_elements_loaded = nb_elements_loaded
158103
def get_context(self) -> dict[str, Any]:
    """Return the parent context extended with this instance's profiling state."""
    context = super().get_context()
    # Keys mirror the keyword arguments accepted by __init__.
    context.update(
        {
            "profiling_enabled": self.profiling_enabled,
            "profile_memory": self.profile_memory,
            "measurements": self.measurements,
            "nb_elements_loaded": self.nb_elements_loaded,
        }
    )
    return context
163111
164112 async def execute_query_with_metadata (
165113 self , query : str , params : dict [str , Any ] | None = None , name : str | None = "undefined"
166114 ) -> tuple [list [Record ], dict [str , Any ]]:
167- if not self .query_analyzer . profile_duration :
115+ if not self .profiling_enabled :
168116 # Profiling might be disabled to avoid capturing queries while loading data
169117 return await super ().execute_query_with_metadata (query , params , name )
170118
171119 # We don't want to memory profile all queries
172- if self .query_analyzer . profile_memory and name in self .queries_names_to_config :
120+ if self .profile_memory and name in self .queries_names_to_config :
173121 # Following call to super().execute_query_with_metadata() will use this value to set PROFILE option
174122 self .queries_names_to_config [name ].profile_memory = True
175123 profile_memory = True
@@ -190,7 +138,34 @@ async def execute_query_with_metadata(
190138 memory = metadata ["profile" ]["args" ]["GlobalMemory" ] if profile_memory else None ,
191139 query_name = str (name ),
192140 start_time = time_start ,
141+ nb_elements_loaded = self .nb_elements_loaded ,
193142 )
194- self .query_analyzer . add_measurement (measurement )
143+ self .measurements . append (measurement )
195144
196145 return response , metadata
146+
147+ def profile (self , profile_memory : bool ) -> Self :
148+ """
149+ This method allows to enable profiling of a InfrahubDatabaseProfiler instance
150+ through a context manager with this syntax:
151+
152+ `with db.profile(profile_memory=...):
153+ # run code to profile
154+ `
155+ """
156+
157+ self .profile_memory = profile_memory
158+ return self
159+
160+ def __enter__ (self ) -> None :
161+ self .profiling_enabled = True
162+ self .profile_memory = self .profile_memory
163+
def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> None:
    """Switch off duration and memory profiling when the ``with`` block ends.

    The exception arguments are ignored and nothing is returned, so an
    exception raised inside the block is not suppressed.
    """
    self.profiling_enabled = self.profile_memory = False
169+
def increase_nb_elements_loaded(self, nb_elements_loaded: int) -> None:
    """Add ``nb_elements_loaded`` to the running count of loaded elements."""
    updated_total = self.nb_elements_loaded + nb_elements_loaded
    self.nb_elements_loaded = updated_total
0 commit comments