Commit 195cec0

Merge pull request #4671 from opsmill/lgu-refactor-query-analyzer
Refactor QueryAnalyzer to GraphProfileGenerator for query benchmarks
2 parents 7474c8c + 501cdf2 commit 195cec0

File tree: 4 files changed, +95 -98 lines

backend/tests/helpers/query_benchmark/data_generator.py

Lines changed: 11 additions & 8 deletions

@@ -1,13 +1,13 @@
 from abc import abstractmethod
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional

 from rich.console import Console
 from rich.progress import Progress

 from tests.helpers.query_benchmark.db_query_profiler import (
+    GraphProfileGenerator,
     InfrahubDatabaseProfiler,
-    ProfilerEnabler,
 )

@@ -36,7 +36,8 @@ async def load_data_and_profile(
     profile_frequency: int,
     graphs_output_location: Path,
     test_label: str,
-    memory_profiling_rate: int = 25,
+    graph_generator: GraphProfileGenerator,
+    memory_profiling_rate: Optional[int] = None,
 ) -> None:
     """
     Loads data using the provided data generator, profiles the execution at specified loading intervals,
@@ -59,7 +60,7 @@ async def load_data_and_profile(
     q, r = divmod(nb_elements, profile_frequency)
     nb_elem_per_batch = [profile_frequency] * q + ([r] if r else [])

-    query_analyzer = data_generator.db.query_analyzer
+    db_profiling_queries = data_generator.db

     with Progress(console=Console(force_terminal=True)) as progress:  # Need force_terminal to display with pytest
         task = progress.add_task(
@@ -68,12 +69,14 @@

         for i, nb_elem_to_load in enumerate(nb_elem_per_batch):
             await data_generator.load_data(nb_elements=nb_elem_to_load)
-            query_analyzer.increase_nb_elements_loaded(profile_frequency)
+            db_profiling_queries.increase_nb_elements_loaded(nb_elem_to_load)
             profile_memory = i % memory_profiling_rate == 0 if memory_profiling_rate is not None else False
-            with ProfilerEnabler(profile_memory=profile_memory, query_analyzer=query_analyzer):
+            with db_profiling_queries.profile(profile_memory):
                 await func_call()
             progress.advance(task)

     # Remove first measurements as queries when there is no data seem always extreme
-    query_analyzer.measurements = [m for m in query_analyzer.measurements if m.nb_elements_loaded != 0]
-    query_analyzer.create_graphs(output_location=graphs_output_location, label=test_label)
+    measurements = [m for m in db_profiling_queries.measurements if m.nb_elements_loaded != 0]
+    graph_generator.create_graphs(
+        measurements=measurements, output_location=graphs_output_location, label=test_label
+    )
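The refactor splits what QueryAnalyzer used to combine: the database wrapper now accumulates measurements itself, while a separate GraphProfileGenerator renders the graphs, so callers pass a graph_generator explicitly. A minimal sketch of the new call, assuming the parameter names data_generator and func_call inferred from the helper's body, with car_generator and run_query_under_test as hypothetical stand-ins:

from pathlib import Path

from tests.helpers.query_benchmark.data_generator import load_data_and_profile
from tests.helpers.query_benchmark.db_query_profiler import GraphProfileGenerator


async def run_benchmark(car_generator, run_query_under_test) -> None:
    # car_generator: a data generator backed by an InfrahubDatabaseProfiler
    # run_query_under_test: the coroutine being benchmarked
    await load_data_and_profile(
        data_generator=car_generator,
        func_call=run_query_under_test,
        nb_elements=1_000,
        profile_frequency=100,
        graphs_output_location=Path("query_performance_results"),
        test_label="example-run",
        graph_generator=GraphProfileGenerator(),  # new required argument
        memory_profiling_rate=25,  # now optional; None disables memory profiling
    )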

backend/tests/helpers/query_benchmark/db_query_profiler.py

Lines changed: 58 additions & 83 deletions

@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Optional, Type
+from typing import Any, List, Optional, Self, Type

 import matplotlib.pyplot as plt
 import pandas as pd
@@ -38,53 +38,27 @@ class QueryMeasurement:
     memory: Optional[float] = None


-class QueryAnalyzer:
-    name: Optional[str]
-    measurements: list[QueryMeasurement]
-    output_location: Path
-    nb_elements_loaded: int
-    profile_memory: bool
-    profile_duration: bool
-
-    def __init__(self) -> None:
-        self.reset()
-
-    def reset(self) -> None:
-        self.name = None
-        self.measurements = []
-        self.output_location = Path.cwd()
-        self.nb_elements_loaded = 0
-        self.profile_duration = False
-        self.profile_memory = False
-
-    def increase_nb_elements_loaded(self, increment: int) -> None:
-        self.nb_elements_loaded += increment
-
-    def get_df(self) -> pd.DataFrame:
+class GraphProfileGenerator:
+    def build_df_from_measuremenst(self, measurements: list[QueryMeasurement]) -> pd.DataFrame:
         data = {}
         for item in QueryMeasurement.__dataclass_fields__.keys():
-            data[item] = [getattr(m, item) for m in self.measurements]
+            data[item] = [getattr(m, item) for m in measurements]

         return pd.DataFrame(data)

-    def add_measurement(self, measurement: QueryMeasurement) -> None:
-        measurement.nb_elements_loaded = self.nb_elements_loaded
-        self.measurements.append(measurement)
-
-    def create_graphs(self, output_location: Path, label: str) -> None:
-        df = self.get_df()
+    def create_graphs(self, measurements: List[QueryMeasurement], output_location: Path, label: str) -> None:
+        df = self.build_df_from_measuremenst(measurements)
         query_names = set(df["query_name"].tolist())

         if not output_location.exists():
             output_location.mkdir(parents=True)

         for query_name in query_names:
-            self.create_duration_graph(query_name=query_name, label=label, output_dir=output_location)
+            self.create_duration_graph(df=df, query_name=query_name, label=label, output_dir=output_location)
             # self.create_memory_graph(query_name=query_name, label=label, output_dir=output_location)

-    def create_duration_graph(self, query_name: str, label: str, output_dir: Path) -> None:
+    def create_duration_graph(self, df: pd.DataFrame, query_name: str, label: str, output_dir: Path) -> None:
         metric = "duration"
-        df = self.get_df()

         name = f"{query_name}_{metric}"
         plt.figure(name)
@@ -105,71 +79,45 @@ def create_duration_graph(self, query_name: str, label: str, output_dir: Path) -> None:
         file_name = f"{name}.png"
         plt.savefig(str(output_dir / file_name), bbox_inches="tight")

-    def create_memory_graph(self, query_name: str, label: str, output_dir: Path) -> None:
-        metric = "memory"
-        df = self.get_df()
-        df_query = df[(df["query_name"] == query_name) & (~df["memory"].isna())]
-
-        plt.figure(query_name)
-
-        x = df_query["nb_elements_loaded"].values
-        y = df_query[metric].values
-
-        plt.plot(x, y, label=label)
-
-        plt.legend()
-
-        plt.ylabel("memory", fontsize=15)
-        plt.title(f"Query - {query_name} | {metric}", fontsize=20)
-
-        file_name = f"{query_name}_{metric}.png"

-        plt.savefig(str(output_dir / file_name))
-
-
-class ProfilerEnabler:
+class InfrahubDatabaseProfiler(InfrahubDatabase):
+    profiling_enabled: bool
     profile_memory: bool
+    measurements: List[QueryMeasurement]
+    nb_elements_loaded: int

-    def __init__(self, profile_memory: bool, query_analyzer: QueryAnalyzer) -> None:
-        self.profile_memory = profile_memory
-        self.query_analyzer = query_analyzer
-
-    def __enter__(self) -> None:
-        self.query_analyzer.profile_duration = True
-        self.query_analyzer.profile_memory = self.profile_memory
-
-    def __exit__(
-        self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
-    ) -> None:
-        self.query_analyzer.profile_duration = False
-        self.query_analyzer.profile_memory = False
-
-
-# Tricky to have it as an attribute of InfrahubDatabaseProfiler as some copies of InfrahubDatabase are made
-# during start_session calls.
-# query_analyzer = QueryAnalyzer()
-
-
-class InfrahubDatabaseProfiler(InfrahubDatabase):
-    def __init__(self, **kwargs: Any) -> None:
+    def __init__(
+        self,
+        profiling_enabled: bool = False,
+        profile_memory: bool = False,
+        measurements: Optional[List[QueryMeasurement]] = None,
+        nb_elements_loaded: int = 0,
+        **kwargs: Any,
+    ) -> None:  # todo args in constructor only because of __class__ pattern
         super().__init__(**kwargs)
-        self.query_analyzer = QueryAnalyzer()
+        self.profiling_enabled = profiling_enabled
+        self.profile_memory = profile_memory
+        self.measurements = measurements if measurements is not None else []
+        self.nb_elements_loaded = nb_elements_loaded
         # Note that any attribute added here should be added to get_context method.

     def get_context(self) -> dict[str, Any]:
         ctx = super().get_context()
-        ctx["query_analyzer"] = self.query_analyzer
+        ctx["profiling_enabled"] = self.profiling_enabled
+        ctx["profile_memory"] = self.profile_memory
+        ctx["measurements"] = self.measurements
+        ctx["nb_elements_loaded"] = self.nb_elements_loaded
         return ctx

     async def execute_query_with_metadata(
         self, query: str, params: dict[str, Any] | None = None, name: str | None = "undefined"
     ) -> tuple[list[Record], dict[str, Any]]:
-        if not self.query_analyzer.profile_duration:
+        if not self.profiling_enabled:
             # Profiling might be disabled to avoid capturing queries while loading data
             return await super().execute_query_with_metadata(query, params, name)

         # We don't want to memory profile all queries
-        if self.query_analyzer.profile_memory and name in self.queries_names_to_config:
+        if self.profile_memory and name in self.queries_names_to_config:
             # Following call to super().execute_query_with_metadata() will use this value to set PROFILE option
             self.queries_names_to_config[name].profile_memory = True
             profile_memory = True
@@ -190,7 +138,34 @@ async def execute_query_with_metadata(
             memory=metadata["profile"]["args"]["GlobalMemory"] if profile_memory else None,
             query_name=str(name),
             start_time=time_start,
+            nb_elements_loaded=self.nb_elements_loaded,
         )
-        self.query_analyzer.add_measurement(measurement)
+        self.measurements.append(measurement)

         return response, metadata
+
+    def profile(self, profile_memory: bool) -> Self:
+        """
+        This method allows to enable profiling of a InfrahubDatabaseProfiler instance
+        through a context manager with this syntax:
+
+        `with db.profile(profile_memory=...):
+            # run code to profile
+        `
+        """
+
+        self.profile_memory = profile_memory
+        return self
+
+    def __enter__(self) -> None:
+        self.profiling_enabled = True
+        self.profile_memory = self.profile_memory
+
+    def __exit__(
+        self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
+    ) -> None:
+        self.profiling_enabled = False
+        self.profile_memory = False
+
+    def increase_nb_elements_loaded(self, nb_elements_loaded: int) -> None:
+        self.nb_elements_loaded += nb_elements_loaded
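The new profile() method uses a small pattern worth calling out: it stashes the option and returns self, so the instance itself is the context manager and __enter__/__exit__ only toggle flags. A stripped-down sketch of that pattern, independent of the Infrahub classes (Profiler is a hypothetical stand-in; typing.Self requires Python 3.11+):

from types import TracebackType
from typing import Optional, Self, Type


class Profiler:
    def __init__(self) -> None:
        self.profiling_enabled = False
        self.profile_memory = False

    def profile(self, profile_memory: bool) -> Self:
        self.profile_memory = profile_memory  # remembered until __exit__ resets it
        return self

    def __enter__(self) -> None:
        self.profiling_enabled = True

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None lets any exception propagate after the flags are cleared.
        self.profiling_enabled = False
        self.profile_memory = False


db = Profiler()
with db.profile(profile_memory=True):
    assert db.profiling_enabled  # queries executed here would be measured
assert not db.profiling_enabled

One consequence of this design: the flags live on the shared instance, so nested or concurrent profile() blocks would interfere with each other.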

backend/tests/query_benchmark/conftest.py

Lines changed: 11 additions & 1 deletion

@@ -5,6 +5,7 @@

 from infrahub.core.constants import BranchSupportType
 from infrahub.core.schema import SchemaRoot
+from tests.helpers.query_benchmark.db_query_profiler import GraphProfileGenerator

 RESULTS_FOLDER = Path(__file__).resolve().parent / "query_performance_results"

@@ -54,10 +55,19 @@ async def car_person_schema_root() -> SchemaRoot:
                 ],
                 "relationships": [
                     {"name": "cars", "peer": "TestCar", "cardinality": "many"},
-                    {"name": "animal", "peer": "TestAnimal", "cardinality": "one"},
                 ],
             },
         ],
     }

     return SchemaRoot(**schema)
+
+
+@pytest.fixture(scope="session")
+async def graph_generator() -> GraphProfileGenerator:
+    """
+    Use GraphProfileGenerator as a fixture as it may allow to properly generate graphs from
+    distinct tests, instead of having each test managing its own display.
+    """
+
+    return GraphProfileGenerator()
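Because the fixture is session-scoped, every test receives the same GraphProfileGenerator instance, which is what lets curves from distinct tests land on shared figures. A minimal, self-contained illustration of that semantics (Recorder and the test names are hypothetical):

import pytest


class Recorder:  # stand-in for GraphProfileGenerator
    def __init__(self) -> None:
        self.labels: list[str] = []


@pytest.fixture(scope="session")
def recorder() -> Recorder:
    # Built once per test session, not once per test.
    return Recorder()


def test_first(recorder: Recorder) -> None:
    recorder.labels.append("run-1")


def test_second(recorder: Recorder) -> None:
    # Same instance as in test_first, because the fixture is session-scoped.
    assert recorder.labels == ["run-1"]

Note that the real fixture is an async function, so collecting it as shown in the diff relies on the project's async pytest setup.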

backend/tests/query_benchmark/test_node_unique_attribute_constraint.py

Lines changed: 15 additions & 6 deletions

@@ -18,7 +18,7 @@
     CarGeneratorWithOwnerHavingUniqueCar,
 )
 from tests.helpers.query_benchmark.data_generator import load_data_and_profile
-from tests.helpers.query_benchmark.db_query_profiler import BenchmarkConfig
+from tests.helpers.query_benchmark.db_query_profiler import BenchmarkConfig, GraphProfileGenerator
 from tests.query_benchmark.conftest import RESULTS_FOLDER
 from tests.query_benchmark.utils import start_db_and_create_default_branch

@@ -28,7 +28,12 @@


 async def benchmark_uniqueness_query(
-    query_request, car_person_schema_root, benchmark_config: BenchmarkConfig, test_params_label: str, test_name: str
+    query_request,
+    car_person_schema_root,
+    graph_generator: GraphProfileGenerator,
+    benchmark_config: BenchmarkConfig,
+    test_params_label: str,
+    test_name: str,
 ):
     """
     Profile NodeUniqueAttributeConstraintQuery with a given query_request / configuration, using a Car generator.
@@ -68,6 +73,7 @@ async def init_and_execute():
         nb_elements=nb_cars,
         graphs_output_location=graph_output_location,
         test_label=test_params_label,
+        graph_generator=graph_generator,
     )


@@ -96,14 +102,15 @@
         ),
     ],
 )
-async def test_multiple_constraints(query_request, car_person_schema_root):
+async def test_multiple_constraints(query_request, car_person_schema_root, graph_generator):
     benchmark_config = BenchmarkConfig(neo4j_runtime=Neo4jRuntime.DEFAULT, neo4j_image=NEO4J_ENTERPRISE_IMAGE)
     await benchmark_uniqueness_query(
         query_request=query_request,
         car_person_schema_root=car_person_schema_root,
         benchmark_config=benchmark_config,
         test_params_label=str(query_request),
         test_name=inspect.currentframe().f_code.co_name,
+        graph_generator=graph_generator,
     )


@@ -115,7 +122,7 @@ async def test_multiple_constraints(query_request, car_person_schema_root):
         BenchmarkConfig(neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE),
     ],
 )
-async def test_multiple_runtimes(benchmark_config, car_person_schema_root):
+async def test_multiple_runtimes(benchmark_config, car_person_schema_root, graph_generator):
     query_request = NodeUniquenessQueryRequest(
         kind="TestCar",
         unique_attribute_paths={
@@ -133,17 +140,18 @@ async def test_multiple_runtimes(benchmark_config, car_person_schema_root):
         benchmark_config=benchmark_config,
         test_params_label=str(benchmark_config),
         test_name=inspect.currentframe().f_code.co_name,
+        graph_generator=graph_generator,
     )


 @pytest.mark.parametrize(
     "benchmark_config",
     [
         BenchmarkConfig(neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False),
-        BenchmarkConfig(neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=True),
+        # BenchmarkConfig(neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=True),
     ],
 )
-async def test_indexes(benchmark_config, car_person_schema_root):
+async def test_indexes(benchmark_config, car_person_schema_root, graph_generator):
     query_request = NodeUniquenessQueryRequest(
         kind="TestCar",
         unique_attribute_paths={
@@ -161,4 +169,5 @@ async def test_indexes(benchmark_config, car_person_schema_root):
         benchmark_config=benchmark_config,
         test_params_label=str(benchmark_config),
         test_name=inspect.currentframe().f_code.co_name,
+        graph_generator=graph_generator,
    )
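A side note on these tests: each one passes its own name to the benchmark helper via inspect.currentframe().f_code.co_name, which evaluates to the name of the enclosing function. A tiny standalone check of that expression:

import inspect


def test_example() -> None:
    name = inspect.currentframe().f_code.co_name
    assert name == "test_example"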
