|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from prefect import flow |
| 4 | +from prefect.logging import get_run_logger |
| 5 | + |
| 6 | +from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow |
| 7 | +from infrahub.core.constants import NumberPoolType |
| 8 | +from infrahub.core.manager import NodeManager |
| 9 | +from infrahub.core.protocols import CoreNumberPool |
| 10 | +from infrahub.core.registry import registry |
| 11 | +from infrahub.core.schema.attribute_parameters import NumberPoolParameters |
| 12 | +from infrahub.pools.registration import get_branches_with_schema_number_pool |
| 13 | +from infrahub.services import InfrahubServices # noqa: TC001 needed for prefect flow |
| 14 | + |
| 15 | + |
| 16 | +@flow( |
| 17 | + name="validate-schema-number-pools", |
| 18 | + flow_run_name="Validate schema number pools on {branch_name}", |
| 19 | +) |
| 20 | +async def validate_schema_number_pools( |
| 21 | + branch_name: str, # noqa: ARG001 |
| 22 | + context: InfrahubContext, # noqa: ARG001 |
| 23 | + service: InfrahubServices, |
| 24 | +) -> None: |
| 25 | + log = get_run_logger() |
| 26 | + |
| 27 | + async with service.database.start_session() as dbs: |
| 28 | + schema_number_pools = await NodeManager.query( |
| 29 | + db=dbs, schema=CoreNumberPool, filters={"pool_type__value": NumberPoolType.SCHEMA.value} |
| 30 | + ) |
| 31 | + |
| 32 | + for schema_number_pool in list(schema_number_pools): |
| 33 | + defined_on_branches = get_branches_with_schema_number_pool( |
| 34 | + kind=schema_number_pool.node.value, attribute_name=schema_number_pool.node_attribute.value |
| 35 | + ) |
| 36 | + if registry.default_branch in defined_on_branches: |
| 37 | + schema = registry.schema.get(name=schema_number_pool.node.value, branch=registry.default_branch) |
| 38 | + attribute = schema.get_attribute(name=schema_number_pool.node_attribute.value) |
| 39 | + number_pool_updated = False |
| 40 | + if isinstance(attribute.parameters, NumberPoolParameters): |
| 41 | + if schema_number_pool.start_range.value != attribute.parameters.start_range: |
| 42 | + schema_number_pool.start_range.value = attribute.parameters.start_range |
| 43 | + number_pool_updated = True |
| 44 | + if schema_number_pool.end_range.value != attribute.parameters.end_range: |
| 45 | + schema_number_pool.end_range.value = attribute.parameters.end_range |
| 46 | + number_pool_updated = True |
| 47 | + |
| 48 | + if number_pool_updated: |
| 49 | + log.info( |
| 50 | + f"Updating NumberPool={schema_number_pool.id} based on changes in the schema on {registry.default_branch}" |
| 51 | + ) |
| 52 | + await schema_number_pool.save(db=service.database) |
| 53 | + |
| 54 | + elif not defined_on_branches: |
| 55 | + log.info(f"Deleting number pool (id={schema_number_pool.id}) as it is no longer defined in the schema") |
| 56 | + await schema_number_pool.delete(db=service.database) |
0 commit comments