Skip to content

Commit cffafd3

Browse files
authored
Merge pull request #4628 from opsmill/dga-20241015-queryanalyzer
Add support for context to InfrahubDatabase
2 parents b7d26f5 + d0db954 commit cffafd3

File tree

9 files changed

+623
-30
lines changed

9 files changed

+623
-30
lines changed

backend/infrahub/database/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,14 @@ def is_transaction(self) -> bool:
182182
return True
183183
return False
184184

185+
def get_context(self) -> dict[str, Any]:
186+
"""
187+
Subclasses override this method to return the extra keyword arguments (typically subclass attributes)
188+
that start_session and start_transaction pass to the self.__class__ constructor when they create a copy of this object.
189+
"""
190+
191+
return {}
192+
185193
def add_schema(self, schema: SchemaBranch, name: Optional[str] = None) -> None:
186194
self._schemas[name or schema.name] = schema
187195

@@ -191,6 +199,8 @@ def start_session(self, read_only: bool = False, schemas: Optional[list[SchemaBr
191199
if read_only:
192200
session_mode = InfrahubDatabaseSessionMode.READ
193201

202+
context = self.get_context()
203+
194204
return self.__class__(
195205
mode=InfrahubDatabaseMode.SESSION,
196206
db_type=self.db_type,
@@ -199,9 +209,12 @@ def start_session(self, read_only: bool = False, schemas: Optional[list[SchemaBr
199209
driver=self._driver,
200210
session_mode=session_mode,
201211
queries_names_to_config=self.queries_names_to_config,
212+
**context,
202213
)
203214

204215
def start_transaction(self, schemas: Optional[list[SchemaBranch]] = None) -> InfrahubDatabase:
216+
context = self.get_context()
217+
205218
return self.__class__(
206219
mode=InfrahubDatabaseMode.TRANSACTION,
207220
db_type=self.db_type,
@@ -211,6 +224,7 @@ def start_transaction(self, schemas: Optional[list[SchemaBranch]] = None) -> Inf
211224
session=self._session,
212225
session_mode=self._session_mode,
213226
queries_names_to_config=self.queries_names_to_config,
227+
**context,
214228
)
215229

216230
async def session(self) -> AsyncSession:

backend/tests/helpers/query_benchmark/data_generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from rich.console import Console
66
from rich.progress import Progress
77

8-
from tests.helpers.query_benchmark.db_query_profiler import InfrahubDatabaseProfiler, ProfilerEnabler, query_analyzer
8+
from tests.helpers.query_benchmark.db_query_profiler import InfrahubDatabaseProfiler, ProfilerEnabler, QueryAnalyzer
99

1010

1111
class DataGenerator:
@@ -33,6 +33,7 @@ async def load_data_and_profile(
3333
profile_frequency: int,
3434
graphs_output_location: Path,
3535
test_label: str,
36+
query_analyzer: QueryAnalyzer,
3637
memory_profiling_rate: int = 25,
3738
) -> None:
3839
"""
@@ -67,7 +68,7 @@ async def load_data_and_profile(
6768
await data_generator.load_data(nb_elements=nb_elem_to_load)
6869
query_analyzer.increase_nb_elements_loaded(profile_frequency)
6970
profile_memory = i % memory_profiling_rate == 0 if memory_profiling_rate is not None else False
70-
with ProfilerEnabler(profile_memory=profile_memory):
71+
with ProfilerEnabler(profile_memory=profile_memory, query_analyzer=query_analyzer):
7172
await func_call()
7273
progress.advance(task)
7374

backend/tests/helpers/query_benchmark/db_query_profiler.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,35 +147,46 @@ def create_memory_graph(self, query_name: str, label: str, output_dir: Path) ->
147147
class ProfilerEnabler:
148148
profile_memory: bool
149149

150-
def __init__(self, profile_memory: bool) -> None:
150+
def __init__(self, profile_memory: bool, query_analyzer: QueryAnalyzer) -> None:
151151
self.profile_memory = profile_memory
152+
self.query_analyzer = query_analyzer
152153

153154
def __enter__(self) -> None:
154-
query_analyzer.profile_duration = True
155-
query_analyzer.profile_memory = self.profile_memory
155+
self.query_analyzer.profile_duration = True
156+
self.query_analyzer.profile_memory = self.profile_memory
156157

157158
def __exit__(
158159
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
159160
) -> None:
160-
query_analyzer.profile_duration = False
161-
query_analyzer.profile_memory = False
161+
self.query_analyzer.profile_duration = False
162+
self.query_analyzer.profile_memory = False
162163

163164

164165
# The analyzer cannot simply be set once on an InfrahubDatabaseProfiler instance, because copies of InfrahubDatabase
165166
# are made during start_session calls; it is therefore forwarded to every copy through get_context().
166-
query_analyzer = QueryAnalyzer()
167+
# query_analyzer = QueryAnalyzer()
167168

168169

169170
class InfrahubDatabaseProfiler(InfrahubDatabase):
171+
def __init__(self, query_analyzer: QueryAnalyzer, **kwargs: Any) -> None:
172+
super().__init__(**kwargs)
173+
self.query_analyzer = query_analyzer
174+
# Note: any attribute added here must also be returned by get_context(), so that copies built via self.__class__ receive it.
175+
176+
def get_context(self) -> dict[str, Any]:
177+
ctx = super().get_context()
178+
ctx["query_analyzer"] = self.query_analyzer
179+
return ctx
180+
170181
async def execute_query_with_metadata(
171182
self, query: str, params: dict[str, Any] | None = None, name: str | None = "undefined"
172183
) -> tuple[list[Record], dict[str, Any]]:
173-
if not query_analyzer.profile_duration:
184+
if not self.query_analyzer.profile_duration:
174185
# Profiling might be disabled to avoid capturing queries while loading data
175186
return await super().execute_query_with_metadata(query, params, name)
176187

177188
# We don't want to memory profile all queries
178-
if query_analyzer.profile_memory and name in self.queries_names_to_config:
189+
if self.query_analyzer.profile_memory and name in self.queries_names_to_config:
179190
# Following call to super().execute_query_with_metadata() will use this value to set PROFILE option
180191
self.queries_names_to_config[name].profile_memory = True
181192
profile_memory = True
@@ -193,6 +204,6 @@ async def execute_query_with_metadata(
193204
query_name=str(name),
194205
start_time=time_start,
195206
)
196-
query_analyzer.add_measurement(measurement)
207+
self.query_analyzer.add_measurement(measurement)
197208

198209
return response, metadata

backend/tests/query_benchmark/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import pytest
22

33
from infrahub.core.schema import SchemaRoot
4+
from tests.helpers.query_benchmark.db_query_profiler import QueryAnalyzer
5+
6+
7+
@pytest.fixture(scope="session")
8+
def query_analyzer() -> QueryAnalyzer:
9+
return QueryAnalyzer()
410

511

612
@pytest.fixture

backend/tests/query_benchmark/test_node_manager_query.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323

2424
log = get_logger()
2525

26-
pytestmark = pytest.mark.skip("Not relevant to test this currently.")
27-
2826

2927
@pytest.mark.parametrize(
3028
"neo4j_image, neo4j_runtime",
@@ -43,14 +41,14 @@
4341
),
4442
],
4543
)
46-
async def test_query_persons(neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root):
44+
async def test_query_persons(query_analyzer, neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root):
4745
queries_names_to_config = {
4846
NodeGetListQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
4947
NodeListGetAttributeQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
5048
NodeListGetInfoQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
5149
}
5250
db_profiling_queries, default_branch = await start_db_and_create_default_branch(
53-
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config
51+
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config, query_analyzer=query_analyzer
5452
)
5553

5654
registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name)
@@ -67,6 +65,7 @@ async def test_query_persons(neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_
6765
profile_frequency=50,
6866
nb_elements=1000,
6967
graphs_output_location=graph_output_location,
68+
query_analyzer=query_analyzer,
7069
test_label=f" data: {neo4j_image}" + f" runtime: {neo4j_runtime}",
7170
)
7271

@@ -88,14 +87,16 @@ async def test_query_persons(neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_
8887
),
8988
],
9089
)
91-
async def test_query_persons_with_isolated_cars(neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root):
90+
async def test_query_persons_with_isolated_cars(
91+
query_analyzer, neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root
92+
):
9293
queries_names_to_config = {
9394
NodeGetListQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
9495
NodeListGetAttributeQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
9596
NodeListGetInfoQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
9697
}
9798
db_profiling_queries, default_branch = await start_db_and_create_default_branch(
98-
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config
99+
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config, query_analyzer=query_analyzer
99100
)
100101

101102
registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name)
@@ -117,6 +118,7 @@ async def test_query_persons_with_isolated_cars(neo4j_image: str, neo4j_runtime:
117118
profile_frequency=50,
118119
nb_elements=1000,
119120
graphs_output_location=graph_output_location,
121+
query_analyzer=query_analyzer,
120122
test_label=f" data: {neo4j_image}" + f" runtime: {neo4j_runtime}",
121123
)
122124

@@ -138,14 +140,16 @@ async def test_query_persons_with_isolated_cars(neo4j_image: str, neo4j_runtime:
138140
),
139141
],
140142
)
141-
async def test_query_persons_with_connected_cars(neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root):
143+
async def test_query_persons_with_connected_cars(
144+
query_analyzer, neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root
145+
):
142146
queries_names_to_config = {
143147
NodeGetListQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
144148
NodeListGetAttributeQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
145149
NodeListGetInfoQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime),
146150
}
147151
db_profiling_queries, default_branch = await start_db_and_create_default_branch(
148-
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config
152+
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config, query_analyzer=query_analyzer
149153
)
150154

151155
registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name)
@@ -163,5 +167,6 @@ async def test_query_persons_with_connected_cars(neo4j_image: str, neo4j_runtime
163167
profile_frequency=50,
164168
nb_elements=1000,
165169
graphs_output_location=graph_output_location,
170+
query_analyzer=query_analyzer,
166171
test_label=f" data: {neo4j_image}" + f" runtime: {neo4j_runtime}",
167172
)

backend/tests/query_benchmark/test_node_unique_attribute_constraint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
CarGenerator,
1616
)
1717
from tests.helpers.query_benchmark.data_generator import load_data_and_profile
18-
from tests.helpers.utils import start_db_and_create_default_branch
18+
19+
from .utils import start_db_and_create_default_branch
1920

2021
RESULTS_FOLDER = Path(__file__).resolve().parent / "query_performance_results"
2122

2223
log = get_logger()
2324

24-
pytestmark = pytest.mark.skip("Not relevant to test this currently.")
25-
2625

2726
@pytest.mark.parametrize(
2827
"neo4j_image, neo4j_runtime",
@@ -42,11 +41,11 @@
4241
],
4342
)
4443
async def test_query_unique_cars_single_attribute(
45-
neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root
44+
query_analyzer, neo4j_image: str, neo4j_runtime: Neo4jRuntime, car_person_schema_root
4645
):
4746
queries_names_to_config = {NodeUniqueAttributeConstraintQuery.name: QueryConfig(neo4j_runtime=neo4j_runtime)}
4847
db_profiling_queries, default_branch = await start_db_and_create_default_branch(
49-
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config
48+
neo4j_image=neo4j_image, queries_names_to_config=queries_names_to_config, query_analyzer=query_analyzer
5049
)
5150

5251
# Register schema
@@ -70,5 +69,6 @@ async def test_query_unique_cars_single_attribute(
7069
profile_frequency=50,
7170
nb_elements=1000,
7271
graphs_output_location=graph_output_location,
72+
query_analyzer=query_analyzer,
7373
test_label=f" data: {neo4j_image}" + f" runtime: {neo4j_runtime}",
7474
)

backend/tests/query_benchmark/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@
66
from infrahub.core.initialization import create_default_branch, create_global_branch, create_root_node
77
from infrahub.core.schema.manager import SchemaManager
88
from infrahub.database import InfrahubDatabaseMode, QueryConfig, get_db
9-
from infrahub.lock import initialize_lock
109
from tests.helpers.constants import PORT_BOLT_NEO4J
11-
from tests.helpers.query_benchmark.db_query_profiler import InfrahubDatabaseProfiler
10+
from tests.helpers.query_benchmark.db_query_profiler import InfrahubDatabaseProfiler, QueryAnalyzer
1211
from tests.helpers.utils import start_neo4j_container
1312

1413

1514
async def start_db_and_create_default_branch(
16-
neo4j_image: str, queries_names_to_config: Optional[dict[str, QueryConfig]] = None
15+
neo4j_image: str, query_analyzer: QueryAnalyzer, queries_names_to_config: Optional[dict[str, QueryConfig]] = None
1716
) -> Tuple[InfrahubDatabaseProfiler, Branch]:
1817
neo4j_container = start_neo4j_container(neo4j_image)
1918
config.SETTINGS.database.port = int(neo4j_container.get_exposed_port(PORT_BOLT_NEO4J))
2019
db = InfrahubDatabaseProfiler(
21-
mode=InfrahubDatabaseMode.DRIVER, driver=await get_db(), queries_names_to_config=queries_names_to_config
20+
mode=InfrahubDatabaseMode.DRIVER,
21+
query_analyzer=query_analyzer,
22+
driver=await get_db(),
23+
queries_names_to_config=queries_names_to_config,
2224
)
23-
initialize_lock()
2425

2526
# Create default branch
2627
await create_root_node(db=db)

0 commit comments

Comments
 (0)