diff --git a/changelog/+398b0883.added.md b/changelog/+398b0883.added.md new file mode 100644 index 00000000..f9554fab --- /dev/null +++ b/changelog/+398b0883.added.md @@ -0,0 +1 @@ +Added a "branch" parameter to the client.clone() method to allow properly cloning a client that targets another branch. diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index fffa8164..8d9de6d2 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -271,6 +271,11 @@ def _build_ip_prefix_allocation_query( input_data={"data": input_data}, ) + def _clone_config(self, branch: str | None = None) -> Config: + config = copy.deepcopy(self.config) + config.default_branch = branch or config.default_branch + return config + class InfrahubClient(BaseClient): """GraphQL Client to interact with Infrahub.""" @@ -847,9 +852,9 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]: self.store.set(node=node) return nodes - def clone(self) -> InfrahubClient: + def clone(self, branch: str | None = None) -> InfrahubClient: """Return a cloned version of the client using the same configuration""" - return InfrahubClient(config=self.config) + return InfrahubClient(config=self._clone_config(branch=branch)) async def execute_graphql( self, @@ -1591,9 +1596,9 @@ def delete(self, kind: str | type[SchemaTypeSync], id: str, branch: str | None = node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id}) node.delete() - def clone(self) -> InfrahubClientSync: + def clone(self, branch: str | None = None) -> InfrahubClientSync: """Return a cloned version of the client using the same configuration""" - return InfrahubClientSync(config=self.config) + return InfrahubClientSync(config=self._clone_config(branch=branch)) def execute_graphql( self, diff --git a/infrahub_sdk/generator.py b/infrahub_sdk/generator.py index 854b3cb4..98fa689f 100644 --- a/infrahub_sdk/generator.py +++ b/infrahub_sdk/generator.py @@ -38,9 +38,7 @@ def __init__( self.params = params or {} self.root_directory = root_directory or os.getcwd() self.generator_instance = generator_instance - self._init_client = client.clone() - self._init_client.config.default_branch = self._init_client.default_branch = self.branch_name - self._init_client.store._default_branch = self.branch_name + self._init_client = client.clone(branch=self.branch_name) self._client: InfrahubClient | None = None self._nodes: list[InfrahubNode] = [] self._related_nodes: list[InfrahubNode] = [] diff --git a/infrahub_sdk/recorder.py b/infrahub_sdk/recorder.py index bf2a715a..40c45dd3 100644 --- a/infrahub_sdk/recorder.py +++ b/infrahub_sdk/recorder.py @@ -31,6 +31,9 @@ def record(response: httpx.Response) -> None: def default(cls) -> NoRecorder: return cls() + def __eq__(self, other: object) -> bool: + return isinstance(other, NoRecorder) + class JSONRecorder(BaseSettings): model_config = SettingsConfigDict(env_prefix="INFRAHUB_JSON_RECORDER_") diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py index 4f7f3ac3..31c38294 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -6,6 +6,7 @@ from infrahub_sdk import InfrahubClient, InfrahubClientSync from infrahub_sdk.exceptions import NodeNotFoundError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync +from tests.unit.sdk.conftest import BothClients pytestmark = pytest.mark.httpx_mock(can_send_already_matched_responses=True) @@ -761,12 +762,37 @@ async def test_query_echo(httpx_mock: HTTPXMock, echo_clients, client_type): @pytest.mark.parametrize("client_type", client_types) -async def test_clone(clients, client_type): +async def test_clone(clients: BothClients, client_type: str) -> None: + """Validate that the configuration of a cloned client is a replica of the original client""" if client_type == "standard": clone = clients.standard.clone() assert clone.config == clients.standard.config assert isinstance(clone, InfrahubClient) + assert clients.standard.default_branch == clone.default_branch else: clone = clients.sync.clone() assert clone.config == clients.sync.config assert isinstance(clone, InfrahubClientSync) + assert clients.sync.default_branch == clone.default_branch + + +@pytest.mark.parametrize("client_type", client_types) +async def test_clone_define_branch(clients: BothClients, client_type: str) -> None: + """Validate that the clone branch parameter sets the correct branch of the cloned client""" + clone_branch = "my_other_branch" + if client_type == "standard": + original_branch = clients.standard.default_branch + clone = clients.standard.clone(branch=clone_branch) + assert clients.standard.store._default_branch == original_branch + assert isinstance(clone, InfrahubClient) + assert clients.standard.default_branch != clone.default_branch + else: + original_branch = clients.standard.default_branch + clone = clients.sync.clone(branch="my_other_branch") + assert clients.sync.store._default_branch == original_branch + assert isinstance(clone, InfrahubClientSync) + assert clients.sync.default_branch != clone.default_branch + + assert clone.default_branch == clone_branch + assert original_branch != clone_branch + assert clone.store._default_branch == clone_branch