Commit debb8a9

patching command framework infrahub db patch plan/apply/revert (#6311)
* WIP patch for removing dup kind-migrated nodes following a merge
* WIP patch framework
* patch commands and framework
* add restore method
* update tests for reverting a patch
* generate cli docs
* update docs
* remove actual patch from this branch
* consolidate file names for patch plans
* make elementId function use dynamic
* one more docstring
* refactor to track deleted IDs in a file and add failure tests
* delete edges first during revert
* move call order around
* small changes for PR feedback
* do vertex/edge adds in transactions
* generate docs again
* fix for reverting deleted edge linked to deleted vertex
1 parent 4d18625 commit debb8a9

File tree

20 files changed: +1609 −0 lines


backend/infrahub/cli/db.py

Lines changed: 2 additions & 0 deletions
@@ -54,12 +54,14 @@
 from infrahub.services.adapters.workflow.local import WorkflowLocalExecution

 from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
+from .patch import patch_app

 if TYPE_CHECKING:
     from infrahub.cli.context import CliContext
     from infrahub.database import InfrahubDatabase

 app = AsyncTyper()
+app.add_typer(patch_app, name="patch")

 PERMISSIONS_AVAILABLE = ["read", "write", "admin"]

backend/infrahub/cli/patch.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
from __future__ import annotations

import importlib
import inspect
import logging
from pathlib import Path
from typing import TYPE_CHECKING

import typer
from infrahub_sdk.async_typer import AsyncTyper
from rich import print as rprint

from infrahub import config
from infrahub.patch.edge_adder import PatchPlanEdgeAdder
from infrahub.patch.edge_deleter import PatchPlanEdgeDeleter
from infrahub.patch.edge_updater import PatchPlanEdgeUpdater
from infrahub.patch.plan_reader import PatchPlanReader
from infrahub.patch.plan_writer import PatchPlanWriter
from infrahub.patch.queries.base import PatchQuery
from infrahub.patch.runner import (
    PatchPlanEdgeDbIdTranslator,
    PatchRunner,
)
from infrahub.patch.vertex_adder import PatchPlanVertexAdder
from infrahub.patch.vertex_deleter import PatchPlanVertexDeleter
from infrahub.patch.vertex_updater import PatchPlanVertexUpdater

from .constants import ERROR_BADGE, SUCCESS_BADGE

if TYPE_CHECKING:
    from infrahub.cli.context import CliContext
    from infrahub.database import InfrahubDatabase


patch_app = AsyncTyper(help="Commands for planning, applying, and reverting database patches")


def get_patch_runner(db: InfrahubDatabase) -> PatchRunner:
    return PatchRunner(
        plan_writer=PatchPlanWriter(),
        plan_reader=PatchPlanReader(),
        edge_db_id_translator=PatchPlanEdgeDbIdTranslator(),
        vertex_adder=PatchPlanVertexAdder(db=db),
        vertex_deleter=PatchPlanVertexDeleter(db=db),
        vertex_updater=PatchPlanVertexUpdater(db=db),
        edge_adder=PatchPlanEdgeAdder(db=db),
        edge_deleter=PatchPlanEdgeDeleter(db=db),
        edge_updater=PatchPlanEdgeUpdater(db=db),
    )


@patch_app.command(name="plan")
async def plan_patch_cmd(
    ctx: typer.Context,
    patch_path: str = typer.Argument(
        help="Path to the file containing the PatchQuery instance to run. Use Python-style dot paths, such as infrahub.cli.patch.queries.base"
    ),
    patch_plans_dir: Path = typer.Option(Path("infrahub-patches"), help="Path to patch plans directory"),  # noqa: B008
    apply: bool = typer.Option(False, help="Apply the patch immediately after creating it"),
    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
) -> None:
    """Create a plan for a given patch and save it in the patch plans directory to be applied/reverted"""
    logging.getLogger("infrahub").setLevel(logging.WARNING)
    logging.getLogger("neo4j").setLevel(logging.ERROR)
    logging.getLogger("prefect").setLevel(logging.ERROR)

    patch_module = importlib.import_module(patch_path)
    patch_query_class = None
    patch_query_class_count = 0
    for _, cls in inspect.getmembers(patch_module, inspect.isclass):
        if issubclass(cls, PatchQuery) and cls is not PatchQuery:
            patch_query_class = cls
            patch_query_class_count += 1

    patch_query_path = f"{PatchQuery.__module__}.{PatchQuery.__name__}"
    if patch_query_class is None:
        rprint(f"{ERROR_BADGE} No subclass of {patch_query_path} found in {patch_path}")
        raise typer.Exit(1)
    if patch_query_class_count > 1:
        rprint(
            f"{ERROR_BADGE} Multiple subclasses of {patch_query_path} found in {patch_path}. Please only define one per file."
        )
        raise typer.Exit(1)

    config.load_and_exit(config_file_name=config_file)

    context: CliContext = ctx.obj
    dbdriver = await context.init_db(retry=1)

    patch_query_instance = patch_query_class(db=dbdriver)
    async with dbdriver.start_session() as db:
        patch_runner = get_patch_runner(db=db)
        patch_plan_dir = await patch_runner.prepare_plan(patch_query_instance, directory=Path(patch_plans_dir))
        rprint(f"{SUCCESS_BADGE} Patch plan created at {patch_plan_dir}")
        if apply:
            await patch_runner.apply(patch_plan_directory=patch_plan_dir)
            rprint(f"{SUCCESS_BADGE} Patch plan successfully applied")

    await dbdriver.close()


@patch_app.command(name="apply")
async def apply_patch_cmd(
    ctx: typer.Context,
    patch_plan_dir: Path = typer.Argument(help="Path to the directory containing a patch plan"),
    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
) -> None:
    """Apply a given patch plan"""
    logging.getLogger("infrahub").setLevel(logging.WARNING)
    logging.getLogger("neo4j").setLevel(logging.ERROR)
    logging.getLogger("prefect").setLevel(logging.ERROR)

    config.load_and_exit(config_file_name=config_file)

    context: CliContext = ctx.obj
    dbdriver = await context.init_db(retry=1)

    if not patch_plan_dir.exists() or not patch_plan_dir.is_dir():
        rprint(f"{ERROR_BADGE} patch_plan_dir must be an existing directory")
        raise typer.Exit(1)

    async with dbdriver.start_session() as db:
        patch_runner = get_patch_runner(db=db)
        await patch_runner.apply(patch_plan_directory=patch_plan_dir)
        rprint(f"{SUCCESS_BADGE} Patch plan successfully applied")

    await dbdriver.close()


@patch_app.command(name="revert")
async def revert_patch_cmd(
    ctx: typer.Context,
    patch_plan_dir: Path = typer.Argument(help="Path to the directory containing a patch plan"),
    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
) -> None:
    """Revert a given patch plan"""
    logging.getLogger("infrahub").setLevel(logging.WARNING)
    logging.getLogger("neo4j").setLevel(logging.ERROR)
    logging.getLogger("prefect").setLevel(logging.ERROR)
    config.load_and_exit(config_file_name=config_file)

    context: CliContext = ctx.obj
    db = await context.init_db(retry=1)

    if not patch_plan_dir.exists() or not patch_plan_dir.is_dir():
        rprint(f"{ERROR_BADGE} patch_plan_dir must be an existing directory")
        raise typer.Exit(1)

    patch_runner = get_patch_runner(db=db)
    await patch_runner.revert(patch_plan_directory=patch_plan_dir)
    rprint(f"{SUCCESS_BADGE} Patch plan successfully reverted")

    await db.close()
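Taken together, these commands hang off the existing infrahub db CLI as a patch sub-app, as registered in backend/infrahub/cli/db.py above. A rough usage sketch, assuming typer's default option naming for the parameters defined in this file (the exact flag spellings are generated by typer and are not shown in this diff), and using a hypothetical module name for the patch:

  infrahub db patch plan my_patches.delete_duplicate_nodes --patch-plans-dir infrahub-patches
  infrahub db patch apply infrahub-patches/<generated-plan-directory>
  infrahub db patch revert infrahub-patches/<generated-plan-directory>

Here my_patches.delete_duplicate_nodes stands in for a module that defines exactly one PatchQuery subclass, and <generated-plan-directory> is whatever directory prepare_plan() creates; plan writes the plan to disk first so it can be inspected, applied with --apply or the apply command, and reverted later.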

backend/infrahub/patch/__init__.py

Whitespace-only changes.
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
from enum import Enum


class PatchPlanFilename(str, Enum):
    VERTICES_TO_ADD = "vertices_to_add.json"
    VERTICES_TO_UPDATE = "vertices_to_update.json"
    VERTICES_TO_DELETE = "vertices_to_delete.json"
    EDGES_TO_ADD = "edges_to_add.json"
    EDGES_TO_UPDATE = "edges_to_update.json"
    EDGES_TO_DELETE = "edges_to_delete.json"
    ADDED_DB_IDS = "added_db_ids.json"
    DELETED_DB_IDS = "deleted_db_ids.json"
    REVERTED_DELETED_DB_IDS = "reverted_deleted_db_ids.json"
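Judging from these filenames, a patch plan directory presumably contains one JSON file per kind of planned change (vertices and edges to add, update, or delete) plus bookkeeping files (added_db_ids.json, deleted_db_ids.json, reverted_deleted_db_ids.json) that record which database IDs were actually created or removed, which is what lets a plan be applied and later reverted; this diff does not show the reader/writer code that produces these files.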
backend/infrahub/patch/edge_adder.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from collections import defaultdict
from dataclasses import asdict
from typing import AsyncGenerator

from infrahub.core.query import QueryType
from infrahub.database import InfrahubDatabase

from .models import EdgeToAdd


class PatchPlanEdgeAdder:
    def __init__(self, db: InfrahubDatabase, batch_size_limit: int = 1000) -> None:
        self.db = db
        self.batch_size_limit = batch_size_limit

    async def _run_add_query(self, edge_type: str, edges_to_add: list[EdgeToAdd]) -> dict[str, str]:
        query = """
        UNWIND $edges_to_add AS edge_to_add
        MATCH (a) WHERE %(id_func_name)s(a) = edge_to_add.from_id
        MATCH (b) WHERE %(id_func_name)s(b) = edge_to_add.to_id
        CREATE (a)-[e:%(edge_type)s]->(b)
        SET e = edge_to_add.after_props
        RETURN edge_to_add.identifier AS abstract_id, %(id_func_name)s(e) AS db_id
        """ % {
            "edge_type": edge_type,
            "id_func_name": self.db.get_id_function_name(),
        }
        edges_to_add_dicts = [asdict(v) for v in edges_to_add]
        # use transaction to make sure we record the results before committing them
        try:
            txn_db = self.db.start_transaction()
            async with txn_db as txn:
                results = await txn.execute_query(
                    query=query, params={"edges_to_add": edges_to_add_dicts}, type=QueryType.WRITE
                )
                abstract_to_concrete_id_map: dict[str, str] = {}
                for result in results:
                    abstract_id = result.get("abstract_id")
                    concrete_id = result.get("db_id")
                    abstract_to_concrete_id_map[abstract_id] = concrete_id
        finally:
            await txn_db.close()
        return abstract_to_concrete_id_map

    async def execute(
        self,
        edges_to_add: list[EdgeToAdd],
    ) -> AsyncGenerator[dict[str, str], None]:
        """
        Create edges_to_add on the database.
        Returns a generator that yields dictionaries mapping EdgeToAdd.identifier to the database-level ID of the newly created edge.
        """
        edges_map_queue: dict[str, list[EdgeToAdd]] = defaultdict(list)
        for edge_to_add in edges_to_add:
            edges_map_queue[edge_to_add.edge_type].append(edge_to_add)
            if len(edges_map_queue[edge_to_add.edge_type]) > self.batch_size_limit:
                yield await self._run_add_query(
                    edge_type=edge_to_add.edge_type,
                    edges_to_add=edges_map_queue[edge_to_add.edge_type],
                )
                edges_map_queue[edge_to_add.edge_type] = []

        for edge_type, edges_group in edges_map_queue.items():
            yield await self._run_add_query(edge_type=edge_type, edges_to_add=edges_group)
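As a rough illustration of how this generator is intended to be consumed (a sketch only; the real driver is the PatchRunner, whose internals are not part of this diff, and the infrahub.patch.models import path is assumed from the relative .models import above), a caller would accumulate each yielded identifier-to-database-ID mapping, for example so it can be persisted to added_db_ids.json:

    from infrahub.database import InfrahubDatabase
    from infrahub.patch.edge_adder import PatchPlanEdgeAdder
    from infrahub.patch.models import EdgeToAdd  # assumed location of the EdgeToAdd dataclass

    async def add_planned_edges(db: InfrahubDatabase, edges_to_add: list[EdgeToAdd]) -> dict[str, str]:
        # Hypothetical helper: run all planned edge additions and collect the
        # mapping of plan-level identifiers to the concrete database IDs created.
        adder = PatchPlanEdgeAdder(db=db, batch_size_limit=1000)
        identifier_to_db_id: dict[str, str] = {}
        async for batch_map in adder.execute(edges_to_add=edges_to_add):
            identifier_to_db_id.update(batch_map)
        return identifier_to_db_id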
backend/infrahub/patch/edge_deleter.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
from typing import AsyncGenerator

from infrahub.core.query import QueryType
from infrahub.database import InfrahubDatabase

from .models import EdgeToDelete


class PatchPlanEdgeDeleter:
    def __init__(self, db: InfrahubDatabase, batch_size_limit: int = 1000) -> None:
        self.db = db
        self.batch_size_limit = batch_size_limit

    async def _run_delete_query(self, ids_to_delete: list[str]) -> set[str]:
        query = """
        MATCH ()-[e]-()
        WHERE %(id_func_name)s(e) IN $ids_to_delete
        DELETE e
        RETURN %(id_func_name)s(e) AS deleted_id
        """ % {"id_func_name": self.db.get_id_function_name()}
        results = await self.db.execute_query(
            query=query, params={"ids_to_delete": ids_to_delete}, type=QueryType.WRITE
        )
        deleted_ids: set[str] = set()
        for result in results:
            deleted_id = result.get("deleted_id")
            deleted_ids.add(deleted_id)
        return deleted_ids

    async def execute(self, edges_to_delete: list[EdgeToDelete]) -> AsyncGenerator[set[str], None]:
        for i in range(0, len(edges_to_delete), self.batch_size_limit):
            ids_to_delete = [e.db_id for e in edges_to_delete[i : i + self.batch_size_limit]]
            yield await self._run_delete_query(ids_to_delete=ids_to_delete)
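The deleter yields the set of edge IDs that each batch actually removed; a caller would presumably collect these so the plan directory records exactly what was deleted (the "track deleted IDs in a file" item in the commit message, and deleted_db_ids.json above). A minimal sketch under that assumption, with the infrahub.patch.models import path assumed from the relative .models import:

    from infrahub.database import InfrahubDatabase
    from infrahub.patch.edge_deleter import PatchPlanEdgeDeleter
    from infrahub.patch.models import EdgeToDelete  # assumed location of the EdgeToDelete dataclass

    async def delete_planned_edges(db: InfrahubDatabase, edges_to_delete: list[EdgeToDelete]) -> set[str]:
        # Hypothetical helper: delete the planned edges in batches and return
        # the database IDs that the queries confirmed as deleted.
        deleter = PatchPlanEdgeDeleter(db=db)
        confirmed_deleted: set[str] = set()
        async for deleted_batch in deleter.execute(edges_to_delete=edges_to_delete):
            confirmed_deleted |= deleted_batch
        return confirmed_deleted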
backend/infrahub/patch/edge_updater.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
from dataclasses import asdict

from infrahub.core.query import QueryType
from infrahub.database import InfrahubDatabase

from .models import EdgeToUpdate


class PatchPlanEdgeUpdater:
    def __init__(self, db: InfrahubDatabase, batch_size_limit: int = 1000) -> None:
        self.db = db
        self.batch_size_limit = batch_size_limit

    async def _run_update_query(self, edges_to_update: list[EdgeToUpdate]) -> None:
        query = """
        UNWIND $edges_to_update AS edge_to_update
        MATCH ()-[e]-()
        WHERE %(id_func_name)s(e) = edge_to_update.db_id
        SET e = edge_to_update.after_props
        """ % {"id_func_name": self.db.get_id_function_name()}
        await self.db.execute_query(
            query=query, params={"edges_to_update": [asdict(e) for e in edges_to_update]}, type=QueryType.WRITE
        )

    async def execute(self, edges_to_update: list[EdgeToUpdate]) -> None:
        for i in range(0, len(edges_to_update), self.batch_size_limit):
            vertices_slice = edges_to_update[i : i + self.batch_size_limit]
            await self._run_update_query(edges_to_update=vertices_slice)
