Skip to content

Commit 5ae98fb

Browse files
committed
Refactor InfrahubEventFilter
1 parent 0ccecd1 commit 5ae98fb

File tree

4 files changed

+154
-180
lines changed

4 files changed

+154
-180
lines changed

backend/infrahub/graphql/queries/event.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from infrahub.exceptions import ValidationError
99
from infrahub.graphql.types.event import EventNodes
1010
from infrahub.task_manager.event import PrefectEvent
11+
from infrahub.task_manager.models import InfrahubEventFilter
1112

1213
if TYPE_CHECKING:
1314
from datetime import datetime
@@ -41,58 +42,43 @@ async def resolve(
4142
if limit > 50:
4243
# Prefect restricts this to 50
4344
raise ValidationError(input_value="The parameter 'limit' can't be above 50")
44-
return await Events.query(
45-
info=info,
46-
branch=branches,
47-
account=account__ids,
48-
limit=limit,
49-
level=level,
50-
event_type=event_type,
45+
46+
event_filter = InfrahubEventFilter.from_filters(
47+
ids=ids,
48+
branches=branches,
49+
account__ids=account__ids,
5150
has_children=has_children,
52-
offset=offset,
51+
event_type=event_type,
5352
related_node__ids=related_node__ids,
5453
primary_node__ids=primary_node__ids,
5554
parent__ids=parent__ids,
56-
ids=ids,
5755
since=since,
5856
until=until,
57+
level=level,
58+
)
59+
60+
return await Events.query(
61+
info=info,
62+
event_filter=event_filter,
63+
limit=limit,
64+
offset=offset,
5965
)
6066

6167
@classmethod
6268
async def query(
6369
cls,
6470
info: GraphQLResolveInfo,
71+
event_filter: InfrahubEventFilter,
6572
limit: int,
6673
offset: int | None = None,
67-
level: int | None = None,
68-
has_children: bool | None = None,
69-
ids: list[str] | None = None,
70-
event_type: list[str] | None = None,
71-
related_node__ids: list[str] | None = None,
72-
primary_node__ids: list[str] | None = None,
73-
parent__ids: list[str] | None = None,
74-
branch: list[str] | None = None,
75-
account: list[str] | None = None,
76-
since: datetime | None = None,
77-
until: datetime | None = None,
7874
) -> dict[str, Any]:
7975
fields = await extract_fields_first_node(info)
8076

8177
prefect_tasks = await PrefectEvent.query(
8278
fields=fields,
83-
ids=ids,
84-
related_node__ids=related_node__ids,
85-
primary_node__ids=primary_node__ids,
86-
parent__ids=parent__ids,
87-
branch=branch,
88-
event_type=event_type,
89-
has_children=has_children,
90-
account=account,
91-
level=level,
79+
event_filter=event_filter,
9280
limit=limit,
9381
offset=offset,
94-
since=since,
95-
until=until,
9682
)
9783
return {
9884
"count": prefect_tasks.get("count", 0),
Lines changed: 4 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1,135 +1,18 @@
11
from __future__ import annotations
22

3-
import uuid
43
from typing import TYPE_CHECKING, Any
54

65
from prefect.client.orchestration import PrefectClient, get_client
7-
from prefect.events.filters import (
8-
EventFilter,
9-
EventIDFilter,
10-
EventNameFilter,
11-
EventOccurredFilter,
12-
EventRelatedFilter,
13-
EventResourceFilter,
14-
)
156
from prefect.events.schemas.events import Event as PrefectEventModel
16-
from prefect.events.schemas.events import ResourceSpecification
177
from pydantic import BaseModel, Field, TypeAdapter
188

19-
from infrahub.core.timestamp import Timestamp
209
from infrahub.log import get_logger
2110
from infrahub.utils import get_nested_dict
2211

2312
log = get_logger()
2413

2514
if TYPE_CHECKING:
26-
from datetime import datetime
27-
28-
29-
class InfrahubEventFilter(EventFilter):
30-
matching_related: list[EventRelatedFilter] = Field(default_factory=list)
31-
32-
def add_account_filter(self, account: list[str] | None) -> None:
33-
if account:
34-
self.matching_related.append(
35-
EventRelatedFilter(
36-
labels=ResourceSpecification(
37-
{"prefect.resource.role": "infrahub.account", "infrahub.resource.id": account}
38-
)
39-
)
40-
)
41-
42-
def add_branch_filter(self, branch: list[str] | None = None) -> None:
43-
if branch:
44-
self.matching_related.append(
45-
EventRelatedFilter(
46-
labels=ResourceSpecification(
47-
{"prefect.resource.role": "infrahub.branch", "infrahub.resource.label": branch}
48-
)
49-
)
50-
)
51-
52-
def add_event_filter(self, level: int | None = None, has_children: bool | None = None) -> None:
53-
event_filter: dict[str, list[str] | str] = {}
54-
if level is not None:
55-
event_filter["infrahub.event.level"] = str(level)
56-
57-
if has_children is not None:
58-
event_filter["infrahub.event.has_children"] = str(has_children).lower()
59-
60-
if event_filter:
61-
event_filter["prefect.resource.role"] = "infrahub.event"
62-
self.matching_related.append(EventRelatedFilter(labels=ResourceSpecification(event_filter)))
63-
64-
def add_event_id_filter(self, ids: list[str] | None = None) -> None:
65-
if ids:
66-
self.id = EventIDFilter(id=[uuid.UUID(id) for id in ids])
67-
68-
def add_event_type_filter(self, event_type: list[str] | None = None) -> None:
69-
if event_type:
70-
self.event = EventNameFilter(name=event_type)
71-
72-
def add_primary_node_filter(self, primary_node__ids: list[str] | None) -> None:
73-
if primary_node__ids:
74-
self.resource = EventResourceFilter(labels=ResourceSpecification({"infrahub.node.id": primary_node__ids}))
75-
76-
def add_parent_filter(self, parent__ids: list[str] | None) -> None:
77-
if parent__ids:
78-
self.matching_related.append(
79-
EventRelatedFilter(
80-
labels=ResourceSpecification(
81-
{"prefect.resource.role": "infrahub.child_event", "infrahub.event_parent.id": parent__ids}
82-
)
83-
)
84-
)
85-
86-
def add_related_node_filter(self, related_node__ids: list[str] | None) -> None:
87-
if related_node__ids:
88-
self.matching_related.append(
89-
EventRelatedFilter(
90-
labels=ResourceSpecification(
91-
{"prefect.resource.role": "infrahub.related.node", "prefect.resource.id": related_node__ids}
92-
)
93-
)
94-
)
95-
96-
@classmethod
97-
def from_filters(
98-
cls,
99-
ids: list[str] | None = None,
100-
account: list[str] | None = None,
101-
related_node__ids: list[str] | None = None,
102-
parent__ids: list[str] | None = None,
103-
primary_node__ids: list[str] | None = None,
104-
event_type: list[str] | None = None,
105-
branch: list[str] | None = None,
106-
level: int | None = None,
107-
has_children: bool | None = None,
108-
since: datetime | None = None,
109-
until: datetime | None = None,
110-
) -> InfrahubEventFilter:
111-
occurred_filter = {}
112-
if since:
113-
occurred_filter["since"] = Timestamp(since.isoformat()).obj
114-
115-
if until:
116-
occurred_filter["until"] = Timestamp(until.isoformat()).obj
117-
118-
if occurred_filter:
119-
filters = cls(occurred=EventOccurredFilter(**occurred_filter))
120-
else:
121-
filters = cls()
122-
123-
filters.add_event_filter(level=level, has_children=has_children)
124-
filters.add_event_id_filter(ids=ids)
125-
filters.add_event_type_filter(event_type=event_type)
126-
filters.add_branch_filter(branch=branch)
127-
filters.add_account_filter(account=account)
128-
filters.add_parent_filter(parent__ids=parent__ids)
129-
filters.add_primary_node_filter(primary_node__ids=primary_node__ids)
130-
filters.add_related_node_filter(related_node__ids=related_node__ids)
131-
132-
return filters
15+
from .models import InfrahubEventFilter
13316

13417

13518
class PrefectEventData(PrefectEventModel):
@@ -260,7 +143,7 @@ async def query_events(
260143
cls,
261144
client: PrefectClient,
262145
limit: int,
263-
filters: EventFilter,
146+
filters: InfrahubEventFilter,
264147
offset: int | None = None,
265148
) -> PrefectEventResponse:
266149
body = {"limit": limit, "filter": filters.model_dump(mode="json", exclude_none=True), "offset": offset}
@@ -278,45 +161,22 @@ async def query_events(
278161
async def query(
279162
cls,
280163
fields: dict[str, Any],
164+
event_filter: InfrahubEventFilter,
281165
limit: int | None = None,
282166
offset: int | None = None,
283-
level: int | None = None,
284-
ids: list[str] | None = None,
285-
branch: list[str] | None = None,
286-
has_children: bool | None = None,
287-
account: list[str] | None = None,
288-
event_type: list[str] | None = None,
289-
related_node__ids: list[str] | None = None,
290-
primary_node__ids: list[str] | None = None,
291-
parent__ids: list[str] | None = None,
292-
since: datetime | None = None,
293-
until: datetime | None = None,
294167
) -> dict[str, Any]:
295168
nodes: list[dict] = []
296169
limit = limit or 50
297170

298171
node_fields = get_nested_dict(nested_dict=fields, keys=["edges", "node"])
299-
filters = InfrahubEventFilter.from_filters(
300-
ids=ids,
301-
branch=branch,
302-
account=account,
303-
has_children=has_children,
304-
event_type=event_type,
305-
related_node__ids=related_node__ids,
306-
primary_node__ids=primary_node__ids,
307-
parent__ids=parent__ids,
308-
since=since,
309-
until=until,
310-
level=level,
311-
)
312172

313173
if not node_fields:
314174
# This means that it's purely a count query and as such we can override the limit to avoid
315175
# returning data that will only be discarded
316176
limit = 1
317177

318178
async with get_client(sync_client=False) as client:
319-
response = await cls.query_events(client=client, filters=filters, limit=limit, offset=offset)
179+
response = await cls.query_events(client=client, filters=event_filter, limit=limit, offset=offset)
320180
nodes = [{"node": event.to_graphql()} for event in response.events]
321181

322182
return {"count": response.count, "edges": nodes}

0 commit comments

Comments
 (0)