Commit 8d4c407

fix(backend): granular locking and transaction fixes
Signed-off-by: Fatih Acar <[email protected]>
Parent commit: 1305343

15 files changed (+374, -216 lines)

backend/infrahub/core/attribute.py

Lines changed: 3 additions & 2 deletions

@@ -581,7 +581,7 @@ def _filter_sensitive(self, value: str, filter_sensitive: bool) -> str:
 
         return value
 
-    async def from_graphql(self, data: dict, db: InfrahubDatabase) -> bool:
+    async def from_graphql(self, data: dict, db: InfrahubDatabase, process_pools: bool = True) -> bool:
         """Update attr from GraphQL payload"""
 
         changed = False
@@ -595,7 +595,8 @@ async def from_graphql(self, data: dict, db: InfrahubDatabase) -> bool:
             changed = True
         elif "from_pool" in data:
             self.from_pool = data["from_pool"]
-            await self.node.handle_pool(db=db, attribute=self, errors=[])
+            if process_pools:
+                await self.node.handle_pool(db=db, attribute=self, errors=[])
             changed = True
 
         if changed and self.is_from_profile:
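
When an attribute payload requests its value from a resource pool, the new process_pools flag decides whether handle_pool() runs immediately or is deferred to the caller. A minimal sketch of the two call patterns, assuming an existing attribute instance and InfrahubDatabase session; the payload shape and pool id are illustrative, not taken from this commit:

    # Illustrative payload: the attribute value should be allocated from a pool.
    payload = {"from_pool": {"id": "core-number-pool-uuid"}}  # hypothetical pool id

    # Default behaviour: record the pool reference and allocate right away.
    changed = await attribute.from_graphql(data=payload, db=db)

    # Deferred behaviour: record the pool reference but skip handle_pool(), so a
    # preview/validation pass does not consume a pool resource.
    changed = await attribute.from_graphql(data=payload, db=db, process_pools=False)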

backend/infrahub/core/constants/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@
     "rels",
     "save",
     "hfid",
+    "process_pools",
 ]
 
 RESERVED_ATTR_GEN_NAMES = ["type"]

backend/infrahub/core/node/__init__.py

Lines changed: 16 additions & 11 deletions

@@ -314,7 +314,9 @@ async def init(
 
         return cls(**attrs)
 
-    async def handle_pool(self, db: InfrahubDatabase, attribute: BaseAttribute, errors: list) -> None:
+    async def handle_pool(
+        self, db: InfrahubDatabase, attribute: BaseAttribute, errors: list, allocate_resources: bool = True
+    ) -> None:
         """Evaluate if a resource has been requested from a pool and apply the resource
 
         This method only works on number pools, currently Integer is the only type that has the from_pool
@@ -325,7 +327,7 @@ async def handle_pool(self, db: InfrahubDatabase, attribute: BaseAttribute, erro
            attribute.from_pool = {"id": attribute.schema.parameters.number_pool_id}
            attribute.is_default = False
 
        if not attribute.from_pool or not allocate_resources:
            return
 
        try:
@@ -485,7 +487,7 @@ async def handle_object_template(self, fields: dict, db: InfrahubDatabase, error
        elif relationship_peers := await relationship.get_peers(db=db):
            fields[relationship_name] = [{"id": peer_id} for peer_id in relationship_peers]
 
-    async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
+    async def _process_fields(self, fields: dict, db: InfrahubDatabase, process_pools: bool = True) -> None:
        errors = []
 
        if "_source" in fields.keys():
@@ -539,7 +541,7 @@ async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
        # Generate Attribute and Relationship and assign them
        # -------------------------------------------
        errors.extend(await self._process_fields_relationships(fields=fields, db=db))
-        errors.extend(await self._process_fields_attributes(fields=fields, db=db))
+        errors.extend(await self._process_fields_attributes(fields=fields, db=db, process_pools=process_pools))
 
        if errors:
            raise ValidationError(errors)
@@ -576,7 +578,9 @@ async def _process_fields_relationships(self, fields: dict, db: InfrahubDatabase
 
        return errors
 
-    async def _process_fields_attributes(self, fields: dict, db: InfrahubDatabase) -> list[ValidationError]:
+    async def _process_fields_attributes(
+        self, fields: dict, db: InfrahubDatabase, process_pools: bool
+    ) -> list[ValidationError]:
        errors: list[ValidationError] = []
 
        for attr_schema in self._schema.attributes:
@@ -601,9 +605,10 @@ async def _process_fields_attributes(self, fields: dict, db: InfrahubDatabase) -
                )
                if not self._existing:
                    attribute: BaseAttribute = getattr(self, attr_schema.name)
-                    await self.handle_pool(db=db, attribute=attribute, errors=errors)
+                    await self.handle_pool(db=db, attribute=attribute, errors=errors, allocate_resources=process_pools)
 
-                    attribute.validate(value=attribute.value, name=attribute.name, schema=attribute.schema)
+                    if process_pools or attribute.from_pool is None:
+                        attribute.validate(value=attribute.value, name=attribute.name, schema=attribute.schema)
            except ValidationError as exc:
                errors.append(exc)
 
@@ -731,7 +736,7 @@ async def process_label(self, db: InfrahubDatabase | None = None) -> None: # no
        self.label.value = " ".join([word.title() for word in self.name.value.split("_")])
        self.label.is_default = False
 
-    async def new(self, db: InfrahubDatabase, id: str | None = None, **kwargs: Any) -> Self:
+    async def new(self, db: InfrahubDatabase, id: str | None = None, process_pools: bool = True, **kwargs: Any) -> Self:
        if id and not is_valid_uuid(id):
            raise ValidationError({"id": f"{id} is not a valid UUID"})
        if id:
@@ -741,7 +746,7 @@ async def new(self, db: InfrahubDatabase, id: str | None = None, **kwargs: Any)
 
        self.id = id or str(UUIDT())
 
-        await self._process_fields(db=db, fields=kwargs)
+        await self._process_fields(db=db, fields=kwargs, process_pools=process_pools)
        await self._process_macros(db=db)
 
        return self
@@ -1046,15 +1051,15 @@ async def to_graphql(
 
        return response
 
-    async def from_graphql(self, data: dict, db: InfrahubDatabase) -> bool:
+    async def from_graphql(self, data: dict, db: InfrahubDatabase, process_pools: bool = True) -> bool:
        """Update object from a GraphQL payload."""
 
        changed = False
 
        for key, value in data.items():
            if key in self._attributes and isinstance(value, dict):
                attribute = getattr(self, key)
-                changed |= await attribute.from_graphql(data=value, db=db)
+                changed |= await attribute.from_graphql(data=value, db=db, process_pools=process_pools)
 
            if key in self._relationships:
                rel: RelationshipManager = getattr(self, key)
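
The process_pools flag threads from Node.new() through _process_fields() and _process_fields_attributes(), so a node can be built in memory without allocating anything from a pool and without failing validation on the still-unresolved value. A sketch of the intended usage, assuming `db`, `branch`, `schema`, and `data` already exist in the calling context (not runnable on its own):

    from infrahub.core.node import Node

    # Preview object: pool references are recorded but handle_pool() is skipped
    # (allocate_resources=False), and validate() is bypassed for pool-backed values.
    preview = await Node.init(db=db, schema=schema, branch=branch)
    await preview.new(db=db, process_pools=False, **data)

    # Real object: the default process_pools=True keeps the previous behaviour and
    # allocates pool resources during new().
    obj = await Node.init(db=db, schema=schema, branch=branch)
    await obj.new(db=db, **data)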

backend/infrahub/core/node/create.py

Lines changed: 30 additions & 58 deletions

@@ -7,7 +7,7 @@
 from infrahub.core.constants import RelationshipCardinality, RelationshipKind
 from infrahub.core.constraint.node.runner import NodeConstraintRunner
 from infrahub.core.node import Node
-from infrahub.core.node.lock_utils import get_kind_lock_names_on_object_mutation
+from infrahub.core.node.lock_utils import get_lock_names_on_object_mutation
 from infrahub.core.protocols import CoreObjectTemplate
 from infrahub.core.schema import GenericSchema
 from infrahub.dependencies.registry import get_component_registry
@@ -171,45 +171,6 @@ async def _do_create_node(
     return obj
 
 
-async def _do_create_node_with_lock(
-    node_class: type[Node],
-    node_constraint_runner: NodeConstraintRunner,
-    db: InfrahubDatabase,
-    schema: NonGenericSchemaTypes,
-    branch: Branch,
-    fields_to_validate: list[str],
-    data: dict[str, Any],
-    at: Timestamp | None = None,
-) -> Node:
-    schema_branch = registry.schema.get_schema_branch(name=branch.name)
-    lock_names = get_kind_lock_names_on_object_mutation(
-        kind=schema.kind, branch=branch, schema_branch=schema_branch, data=dict(data)
-    )
-
-    if lock_names:
-        async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
-            return await _do_create_node(
-                node_class=node_class,
-                node_constraint_runner=node_constraint_runner,
-                db=db,
-                schema=schema,
-                branch=branch,
-                fields_to_validate=fields_to_validate,
-                data=data,
-                at=at,
-            )
-    return await _do_create_node(
-        node_class=node_class,
-        node_constraint_runner=node_constraint_runner,
-        db=db,
-        schema=schema,
-        branch=branch,
-        fields_to_validate=fields_to_validate,
-        data=data,
-        at=at,
-    )
-
-
 async def create_node(
     data: dict[str, Any],
     db: InfrahubDatabase,
@@ -223,37 +184,48 @@ async def create_node(
        raise ValueError(f"Node of generic schema `{schema.name=}` can not be instantiated.")
 
    component_registry = get_component_registry()
-    node_constraint_runner = await component_registry.get_component(
-        NodeConstraintRunner, db=db.start_session() if not db.is_transaction else db, branch=branch
-    )
    node_class = Node
    if schema.kind in registry.node:
        node_class = registry.node[schema.kind]
 
    fields_to_validate = list(data)
-    if db.is_transaction:
-        obj = await _do_create_node_with_lock(
-            node_class=node_class,
-            node_constraint_runner=node_constraint_runner,
-            db=db,
-            schema=schema,
-            branch=branch,
-            fields_to_validate=fields_to_validate,
-            data=data,
-            at=at,
-        )
-    else:
-        async with db.start_transaction() as dbt:
-            obj = await _do_create_node_with_lock(
+
+    preview_obj = await node_class.init(db=db, schema=schema, branch=branch)
+    await preview_obj.new(db=db, process_pools=False, **data)
+    schema_branch = db.schema.get_schema_branch(name=branch.name)
+    lock_names = get_lock_names_on_object_mutation(node=preview_obj, schema_branch=schema_branch)
+
+    obj: Node
+    async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names, metrics=False):
+        if db.is_transaction:
+            node_constraint_runner = await component_registry.get_component(NodeConstraintRunner, db=db, branch=branch)
+
+            obj = await _do_create_node(
                node_class=node_class,
                node_constraint_runner=node_constraint_runner,
-                db=dbt,
+                db=db,
                schema=schema,
                branch=branch,
                fields_to_validate=fields_to_validate,
                data=data,
                at=at,
            )
+        else:
+            async with db.start_transaction() as dbt:
+                node_constraint_runner = await component_registry.get_component(
+                    NodeConstraintRunner, db=dbt, branch=branch
+                )
+
+                obj = await _do_create_node(
+                    node_class=node_class,
+                    node_constraint_runner=node_constraint_runner,
+                    db=dbt,
+                    schema=schema,
+                    branch=branch,
+                    fields_to_validate=fields_to_validate,
+                    data=data,
+                    at=at,
+                )
 
    if await get_profile_ids(db=db, obj=obj):
        node_profiles_applier = NodeProfilesApplier(db=db, branch=branch)
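
With the helper removed, create_node() now derives its lock names from a throwaway preview object and only then runs the constraint checks and the write, so the entire transaction happens under the locks. The names are returned sorted (see the lock_utils.py diff below), which keeps two concurrent writers from acquiring overlapping locks in opposite orders. A self-contained sketch of that acquire-sorted-then-transact pattern using plain asyncio locks; InfrahubMultiLock itself coordinates locks across workers, so this only illustrates the ordering, and the lock names used here are made up:

    import asyncio
    from contextlib import AsyncExitStack

    # Stand-in registry of named locks; the real lock.registry is shared between workers.
    _locks: dict[str, asyncio.Lock] = {}


    def get_lock(name: str) -> asyncio.Lock:
        return _locks.setdefault(name, asyncio.Lock())


    async def create_under_locks(lock_names: list[str]) -> None:
        async with AsyncExitStack() as stack:
            # Acquire every lock in sorted order before doing any work, mirroring the
            # "lock first, then start the transaction" ordering of create_node().
            for name in sorted(lock_names):
                await stack.enter_async_context(get_lock(name))
            # ... run constraint checks and the database transaction here ...
            await asyncio.sleep(0)


    asyncio.run(create_under_locks(["resource_pool.pool-1234", "object.TestKind"]))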

backend/infrahub/core/node/lock_utils.py

Lines changed: 70 additions & 44 deletions

@@ -1,12 +1,15 @@
 import hashlib
-from typing import Any
+from typing import TYPE_CHECKING
 
-from infrahub.core.branch import Branch
-from infrahub.core.constants.infrahubkind import GENERICGROUP, GRAPHQLQUERYGROUP
+from infrahub.core.node import Node
 from infrahub.core.schema import GenericSchema
 from infrahub.core.schema.schema_branch import SchemaBranch
 
-KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED = [GENERICGROUP]
+if TYPE_CHECKING:
+    from infrahub.core.relationship import RelationshipManager
+
+
+RESOURCE_POOL_LOCK_NAMESPACE = "resource_pool"
 
 
 def _get_kinds_to_lock_on_object_mutation(kind: str, schema_branch: SchemaBranch) -> list[str]:
@@ -43,55 +46,78 @@ def _get_kinds_to_lock_on_object_mutation(kind: str, schema_branch: SchemaBranch
     return kinds
 
 
-def _should_kind_be_locked_on_any_branch(kind: str, schema_branch: SchemaBranch) -> bool:
-    """
-    Check whether kind or any kind generic is in KINDS_TO_LOCK_ON_ANY_BRANCH.
-    """
-
-    if kind in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED:
-        return True
-
-    node_schema = schema_branch.get(name=kind, duplicate=False)
-    if isinstance(node_schema, GenericSchema):
-        return False
-
-    for generic_kind in node_schema.inherit_from:
-        if generic_kind in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED:
-            return True
-    return False
-
-
 def _hash(value: str) -> str:
     # Do not use builtin `hash` for lock names as due to randomization results would differ between
     # different processes.
     return hashlib.sha256(value.encode()).hexdigest()
 
 
-def get_kind_lock_names_on_object_mutation(
-    kind: str, branch: Branch, schema_branch: SchemaBranch, data: dict[str, Any]
-) -> list[str]:
+def get_lock_names_on_object_mutation(node: Node, schema_branch: SchemaBranch) -> list[str]:
     """
-    Return objects kind for which we want to avoid concurrent mutation (create/update). Except for some specific kinds,
-    concurrent mutations are only allowed on non-main branch as objects validations will be performed at least when merging in main branch.
+    Return lock names for object on which we want to avoid concurrent mutation (create/update).
+    Lock names include kind, some generic kinds, resource pool ids, and values of attributes of corresponding uniqueness constraints.
     """
 
-    if not branch.is_default and not _should_kind_be_locked_on_any_branch(kind=kind, schema_branch=schema_branch):
-        return []
-
-    if kind == GRAPHQLQUERYGROUP:
-        # Lock on name as well to improve performances
-        try:
-            name = data["name"].value
-            return [build_object_lock_name(kind + "." + _hash(name))]
-        except KeyError:
-            # We might reach here if we are updating a CoreGraphQLQueryGroup without updating the name,
-            # in which case we would not need to lock. This is not supposed to happen as current `update`
-            # logic first fetches the node with its name.
-            return []
-
-    lock_kinds = _get_kinds_to_lock_on_object_mutation(kind, schema_branch)
-    lock_names = [build_object_lock_name(kind) for kind in lock_kinds]
-    return lock_names
+    lock_names: set[str] = set()
+
+    # Check if node is using resource manager allocation via attributes
+    for attr_name in node.get_schema().attribute_names:
+        attribute = getattr(node, attr_name, None)
+        if attribute is not None and getattr(attribute, "from_pool", None) and "id" in attribute.from_pool:
+            lock_names.add(f"{RESOURCE_POOL_LOCK_NAMESPACE}.{attribute.from_pool['id']}")
+
+    # Check if relationships allocate resources
+    for rel_name in node._relationships:
+        rel_manager: RelationshipManager = getattr(node, rel_name)
+        for rel in rel_manager._relationships:
+            if rel.from_pool and "id" in rel.from_pool:
+                lock_names.add(f"{RESOURCE_POOL_LOCK_NAMESPACE}.{rel.from_pool['id']}")
+
+    lock_kinds = _get_kinds_to_lock_on_object_mutation(node.get_kind(), schema_branch)
+    for kind in lock_kinds:
+        schema = schema_branch.get(name=kind, duplicate=False)
+        ucs = schema.uniqueness_constraints
+        if ucs is None:
+            continue
+
+        ucs_lock_names: list[str] = []
+        uc_attributes_names = set()
+
+        for uc in ucs:
+            uc_attributes_values = []
+            # Keep only attributes constraints
+            for field_path in uc:
+                # Some attributes may exist in different uniqueness constraints, we de-duplicate them
+                if field_path in uc_attributes_names:
+                    continue
+
+                # Exclude relationships uniqueness constraints
+                schema_path = schema.parse_schema_path(path=field_path, schema=schema_branch)
+                if schema_path.related_schema is not None or schema_path.attribute_schema is None:
+                    continue
+
+                uc_attributes_names.add(field_path)
+                attr = getattr(node, schema_path.attribute_schema.name, None)
+                if attr is None or attr.value is None:
+                    # `attr.value` being None corresponds to optional unique attribute.
+                    # `attr` being None is not supposed to happen.
+                    value_hashed = _hash("")
+                else:
+                    value_hashed = _hash(str(attr.value))
+
+                uc_attributes_values.append(value_hashed)
+
+            if uc_attributes_values:
+                uc_lock_name = ".".join(uc_attributes_values)
+                ucs_lock_names.append(uc_lock_name)
+
+        if not ucs_lock_names:
+            continue
+
+        partial_lock_name = kind + "." + ".".join(ucs_lock_names)
+        lock_names.add(build_object_lock_name(partial_lock_name))
+
+    return sorted(lock_names)
 
 
 def build_object_lock_name(name: str) -> str:
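
To make the shape of these names concrete: each `from_pool` reference contributes a `resource_pool.<pool id>` name, and each kind with a uniqueness constraint contributes one name built from the kind plus a hashed component per constrained attribute value. A self-contained sketch of that composition with made-up kind, value, and pool id; the `object.` prefix is a placeholder, since the body of build_object_lock_name() is not part of this diff:

    import hashlib

    RESOURCE_POOL_LOCK_NAMESPACE = "resource_pool"


    def _hash(value: str) -> str:
        # sha256 rather than the builtin hash(): stable across processes.
        return hashlib.sha256(value.encode()).hexdigest()


    def build_object_lock_name(name: str) -> str:
        # Placeholder prefix; the real helper lives in lock_utils.py and may differ.
        return f"object.{name}"


    # Hypothetical node: kind "TestPerson" with a uniqueness constraint on its
    # "name" attribute (value "Alice") and an attribute allocated from pool "pool-1234".
    lock_names = {
        f"{RESOURCE_POOL_LOCK_NAMESPACE}.pool-1234",
        build_object_lock_name("TestPerson." + _hash("Alice")),
    }

    print(sorted(lock_names))  # sorted, as in get_lock_names_on_object_mutation()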
