| 
1 | 1 | from __future__ import annotations  | 
2 | 2 | 
 
  | 
 | 3 | +import asyncio  | 
3 | 4 | from collections import defaultdict  | 
4 | 5 | from collections.abc import MutableMapping  | 
5 | 6 | from enum import Enum  | 
6 | 7 | from pathlib import Path  | 
 | 8 | +from time import sleep  | 
7 | 9 | from typing import TYPE_CHECKING, Any, Optional, TypedDict, TypeVar, Union  | 
8 | 10 | from urllib.parse import urlencode  | 
9 | 11 | 
 
  | 
 | 
22 | 24 | )  | 
23 | 25 | from .generator import InfrahubGenerator  | 
24 | 26 | from .graphql import Mutation  | 
 | 27 | +from .queries import SCHEMA_HASH_SYNC_STATUS  | 
25 | 28 | from .transforms import InfrahubTransform  | 
26 | 29 | from .utils import duplicates  | 
27 | 30 | 
 
  | 
@@ -616,15 +619,36 @@ async def all(  | 
616 | 619 | 
 
  | 
617 | 620 |         return self.cache[branch]  | 
618 | 621 | 
 
  | 
619 |  | -    async def load(self, schemas: list[dict], branch: Optional[str] = None) -> SchemaLoadResponse:  | 
 | 622 | +    async def load(  | 
 | 623 | +        self, schemas: list[dict], branch: Optional[str] = None, wait_until_converged: bool = False  | 
 | 624 | +    ) -> SchemaLoadResponse:  | 
620 | 625 |         branch = branch or self.client.default_branch  | 
621 | 626 |         url = f"{self.client.address}/api/schema/load?branch={branch}"  | 
622 | 627 |         response = await self.client._post(  | 
623 | 628 |             url=url, timeout=max(120, self.client.default_timeout), payload={"schemas": schemas}  | 
624 | 629 |         )  | 
625 | 630 | 
 
  | 
 | 631 | +        if wait_until_converged:  | 
 | 632 | +            await self.wait_until_converged(branch=branch)  | 
 | 633 | + | 
626 | 634 |         return self._validate_load_schema_response(response=response)  | 
627 | 635 | 
 
  | 
 | 636 | +    async def wait_until_converged(self, branch: Optional[str] = None) -> None:  | 
 | 637 | +        """Wait until the schema has converged on the selected branch or the timeout has been reached"""  | 
 | 638 | +        waited = 0  | 
 | 639 | +        while True:  | 
 | 640 | +            status = await self.client.execute_graphql(query=SCHEMA_HASH_SYNC_STATUS, branch_name=branch)  | 
 | 641 | +            if status["InfrahubStatus"]["summary"]["schema_hash_synced"]:  | 
 | 642 | +                self.client.log.info(f"Schema successfully converged after {waited} seconds")  | 
 | 643 | +                return  | 
 | 644 | + | 
 | 645 | +            if waited >= self.client.config.schema_converge_timeout:  | 
 | 646 | +                self.client.log.warning(f"Schema not converged after {waited} seconds, proceeding regardless")  | 
 | 647 | +                return  | 
 | 648 | + | 
 | 649 | +            waited += 1  | 
 | 650 | +            await asyncio.sleep(delay=1)  | 
 | 651 | + | 
628 | 652 |     async def check(self, schemas: list[dict], branch: Optional[str] = None) -> tuple[bool, Optional[dict]]:  | 
629 | 653 |         branch = branch or self.client.default_branch  | 
630 | 654 |         url = f"{self.client.address}/api/schema/check?branch={branch}"  | 
@@ -999,15 +1023,36 @@ def fetch(  | 
999 | 1023 | 
 
  | 
1000 | 1024 |         return nodes  | 
1001 | 1025 | 
 
  | 
1002 |  | -    def load(self, schemas: list[dict], branch: Optional[str] = None) -> SchemaLoadResponse:  | 
 | 1026 | +    def load(  | 
 | 1027 | +        self, schemas: list[dict], branch: Optional[str] = None, wait_until_converged: bool = False  | 
 | 1028 | +    ) -> SchemaLoadResponse:  | 
1003 | 1029 |         branch = branch or self.client.default_branch  | 
1004 | 1030 |         url = f"{self.client.address}/api/schema/load?branch={branch}"  | 
1005 | 1031 |         response = self.client._post(  | 
1006 | 1032 |             url=url, timeout=max(120, self.client.default_timeout), payload={"schemas": schemas}  | 
1007 | 1033 |         )  | 
1008 | 1034 | 
 
  | 
 | 1035 | +        if wait_until_converged:  | 
 | 1036 | +            self.wait_until_converged(branch=branch)  | 
 | 1037 | + | 
1009 | 1038 |         return self._validate_load_schema_response(response=response)  | 
1010 | 1039 | 
 
  | 
 | 1040 | +    def wait_until_converged(self, branch: Optional[str] = None) -> None:  | 
 | 1041 | +        """Wait until the schema has converged on the selected branch or the timeout has been reached"""  | 
 | 1042 | +        waited = 0  | 
 | 1043 | +        while True:  | 
 | 1044 | +            status = self.client.execute_graphql(query=SCHEMA_HASH_SYNC_STATUS, branch_name=branch)  | 
 | 1045 | +            if status["InfrahubStatus"]["summary"]["schema_hash_synced"]:  | 
 | 1046 | +                self.client.log.info(f"Schema successfully converged after {waited} seconds")  | 
 | 1047 | +                return  | 
 | 1048 | + | 
 | 1049 | +            if waited >= self.client.config.schema_converge_timeout:  | 
 | 1050 | +                self.client.log.warning(f"Schema not converged after {waited} seconds, proceeding regardless")  | 
 | 1051 | +                return  | 
 | 1052 | + | 
 | 1053 | +            waited += 1  | 
 | 1054 | +            sleep(1)  | 
 | 1055 | + | 
1011 | 1056 |     def check(self, schemas: list[dict], branch: Optional[str] = None) -> tuple[bool, Optional[dict]]:  | 
1012 | 1057 |         branch = branch or self.client.default_branch  | 
1013 | 1058 |         url = f"{self.client.address}/api/schema/check?branch={branch}"  | 
 | 
0 commit comments