diff --git a/src/pytest_ansible_network_integration/__init__.py b/src/pytest_ansible_network_integration/__init__.py index dcc2ec1..96ed923 100644 --- a/src/pytest_ansible_network_integration/__init__.py +++ b/src/pytest_ansible_network_integration/__init__.py @@ -22,6 +22,7 @@ from .exceptions import PytestNetworkError from .utils import _github_action_log from .utils import _inventory +from .utils import _inventory_multi from .utils import _print from .utils import calculate_ports from .utils import playbook @@ -406,6 +407,138 @@ def _appliance_dhcp_address(env_vars: Dict[str, str]) -> Generator[str, None, No _github_action_log("::endgroup::") +@pytest.fixture(scope="session", name="appliance_dhcp_map") +def _appliance_dhcp_map(env_vars: Dict[str, str]) -> Generator[Dict[str, str], None, None]: + """Provision the lab and collect DHCP addresses for all appliances. + + Returns a mapping of device name to DHCP IP for all devices in the lab. + """ + _github_action_log("::group::Starting lab provisioning (multi-device)") + _print("Starting lab provisioning (multi-device)") + + try: + if not OPTIONS: + raise PytestNetworkError("Missing CML lab options") + + lab_file = OPTIONS.cml_lab + if not os.path.exists(lab_file): + raise PytestNetworkError(f"Missing lab file '{lab_file}'") + + start = time.time() + cml = CmlWrapper( + host=env_vars["cml_host"], + username=env_vars["cml_ui_user"], + password=env_vars["cml_ui_password"], + ) + cml.bring_up(file=lab_file) + lab_id = cml.current_lab_id + logger.debug("Lab ID: %s", lab_id) + + virsh = VirshWrapper( + host=env_vars["cml_host"], + user=env_vars["cml_ssh_user"], + password=env_vars["cml_ssh_password"], + port=int(env_vars["cml_ssh_port"]), + ) + + wait_extra_time = OPTIONS.wait_extra + wait_seconds = 0 + if wait_extra_time: + try: + wait_seconds = int(wait_extra_time) + except ValueError: + logger.warning( + "Invalid wait_extra value: '%s'. Expected an integer. Skipping extra wait.", + wait_extra_time, + ) + wait_seconds = 0 + + try: + device_to_ip = virsh.get_dhcp_leases(lab_id, wait_seconds) + except PytestNetworkError as exc: + logger.error("Failed to get DHCP leases for the appliances") + virsh.close() + cml.remove() + raise PytestNetworkError("Failed to get DHCP leases for the appliances") from exc + + end = time.time() + elapsed = end - start + _print(f"Elapsed time to provision (multi): {elapsed} seconds") + logger.info("Elapsed time to provision (multi): %s seconds", elapsed) + + except PytestNetworkError as exc: + logger.error("Failed to provision lab (multi): %s", exc) + _github_action_log("::endgroup::") + raise + + finally: + virsh.close() + _github_action_log("::endgroup::") + + yield device_to_ip + + _github_action_log("::group::Removing lab (multi)") + try: + cml.remove() + except PytestNetworkError as exc: + logger.error("Failed to remove lab (multi): %s", exc) + raise + finally: + _github_action_log("::endgroup::") + + +@pytest.fixture +def ansible_project_multi( + appliance_dhcp_map: Dict[str, str], + env_vars: Dict[str, str], + integration_test_path: Path, + tmp_path: Path, +) -> AnsibleProject: + """Build an Ansible project for all discovered appliances. + + Creates a multi-host inventory using all DHCP leases discovered. + """ + logger.info("Building the Ansible project for multiple devices") + + inventory = _inventory_multi( + host=env_vars["cml_host"], + device_to_ip=appliance_dhcp_map, + network_os=env_vars["network_os"], + username=env_vars["device_username"], + password=env_vars["device_password"], + ) + logger.debug("Generated multi-host inventory: %s", inventory) + + inventory_path = tmp_path / "inventory.json" + with inventory_path.open(mode="w", encoding="utf-8") as fh: + json.dump(inventory, fh) + logger.debug("Inventory written to %s", inventory_path) + + playbook_contents = playbook(hosts="all", role=str(integration_test_path)) + playbook_path = tmp_path / "site.json" + with playbook_path.open(mode="w", encoding="utf-8") as fh: + json.dump(playbook_contents, fh) + logger.debug("Playbook written to %s", playbook_path) + + _print(f"Inventory path: {inventory_path}") + _print(f"Playbook path: {playbook_path}") + + project = AnsibleProject( + collection_doc_cache=tmp_path / "collection_doc_cache.db", + directory=tmp_path, + inventory=inventory_path, + log_file=Path.home() / "test_logs" / f"{integration_test_path.name}.log", + playbook=playbook_path, + playbook_artifact=Path.home() + / "test_logs" + / "{playbook_status}" + / f"{integration_test_path.name}.json", + role=integration_test_path.name, + ) + logger.info("Ansible multi-host project created successfully") + return project + + def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: """Generate tests based on the integration test paths. diff --git a/src/pytest_ansible_network_integration/defs.py b/src/pytest_ansible_network_integration/defs.py index 5ce364f..c40a659 100644 --- a/src/pytest_ansible_network_integration/defs.py +++ b/src/pytest_ansible_network_integration/defs.py @@ -300,12 +300,55 @@ def get_dhcp_lease(self, current_lab_id: str, wait_extra: int) -> str: logger.info("Done waiting, starting to find IPs") if len(ips) > 1: - logger.error("Found more than one IP: %s", ips) + logger.error("SSSSSSSSSSSSSSSSS Found more than one IP: %s", ips) raise PytestNetworkError("Found more than one IP") logger.info("DHCP lease IP found: %s", ips[0]) return ips[0] + def get_dhcp_leases(self, current_lab_id: str, wait_extra: int) -> Dict[str, str]: + """Get DHCP leases for all devices in the specified lab. + + :param current_lab_id: The current lab ID. + :param wait_extra: Extra seconds to wait before resolving leases. + :raises PytestNetworkError: If no leases can be found. + :return: Mapping of device name to its IP address. + """ + logger.info("Getting all current lab domains from virsh") + domains = self._find_current_lab_domains(current_lab_id, 20) + + if wait_extra: + logger.info("Waiting for extra %s seconds before resolving leases", wait_extra) + time.sleep(wait_extra) + + device_to_ip: Dict[str, str] = {} + for domain in domains: + try: + device_name = domain["domain"]["name"] + except KeyError as e: + logger.error("Failed to extract device name from domain: %s", e) + raise PytestNetworkError(f"Failed to extract device name: {e}") from e + + macs = self._extract_macs(domain) + ips = self._find_dhcp_lease(macs, 200) + + if not ips: + logger.error("No IP found for device '%s'", device_name) + raise PytestNetworkError(f"No IP found for device '{device_name}'") + + if len(ips) > 1: + logger.warning( + "Multiple IPs found for device '%s' (MACs: %s), choosing first: %s", + device_name, + macs, + ips, + ) + + device_to_ip[device_name] = ips[0] + + logger.info("Resolved DHCP leases for devices: %s", device_to_ip) + return device_to_ip + def _find_current_lab(self, current_lab_id: str, max_attempts: int = 20) -> Dict[str, Any]: """Find the current lab by its ID. @@ -350,6 +393,60 @@ def _find_current_lab(self, current_lab_id: str, max_attempts: int = 20) -> Dict logger.error("Could not find current lab after %s attempts", attempt) raise PytestNetworkError("Could not find current lab") + def _find_current_lab_domains( + self, current_lab_id: str, max_attempts: int = 20 + ) -> List[Dict[str, Any]]: + """Find all domains for the current lab by its ID. + + Iterates over all virsh domains and collects those whose XML includes the + given lab ID. Retries up to max_attempts times. + + :param current_lab_id: The current lab ID. + :param max_attempts: Maximum attempts to discover lab domains. + :raises PytestNetworkError: If no domains are found for the lab. + :return: A list of domain XML dicts. + """ + attempt = 0 + while attempt < max_attempts: + logger.info("Attempt %s to find all current lab domains", attempt) + stdout, _stderr = self.ssh.execute("sudo virsh list --all") + logger.debug("virsh list output: %s", stdout) + if _stderr: + logger.error("virsh list stderr: %s", _stderr) + + virsh_matches = [re.match(r"^\s(?P\d+)", line) for line in stdout.splitlines()] + if not any(virsh_matches): + logger.error("No matching virsh IDs found in the output") + raise PytestNetworkError("No matching virsh IDs found") + + try: + virsh_ids = [ + virsh_match.groupdict()["id"] for virsh_match in virsh_matches if virsh_match + ] + except KeyError as e: + error_message = f"Failed to extract virsh IDs: {e}" + logger.error(error_message) + raise PytestNetworkError(error_message) from e + + matched_domains: List[Dict[str, Any]] = [] + for virsh_id in virsh_ids: + stdout, _stderr = self.ssh.execute(f"sudo virsh dumpxml {virsh_id}") + if current_lab_id in stdout: + logger.debug( + "Found lab %s in virsh dumpxml for ID %s", current_lab_id, virsh_id + ) + xmltodict_data = xmltodict.parse(stdout) + matched_domains.append(xmltodict_data) # type: ignore + + if matched_domains: + return matched_domains + + attempt += 1 + time.sleep(5) + + logger.error("Could not find any domains for current lab after %s attempts", attempt) + raise PytestNetworkError("Could not find any domains for current lab") + def _extract_macs(self, current_lab: Dict[str, Any]) -> List[str]: """Extract MAC addresses from the current lab. @@ -408,7 +505,7 @@ def _find_dhcp_lease(self, macs: List[str], max_attempts: int = 100) -> List[str return ips attempt += 1 - time.sleep(10) + time.sleep(400) logger.error("Could not find IPs after %s attempts", attempt) raise PytestNetworkError("Could not find IPs") diff --git a/src/pytest_ansible_network_integration/utils.py b/src/pytest_ansible_network_integration/utils.py index 2637f38..eefe335 100644 --- a/src/pytest_ansible_network_integration/utils.py +++ b/src/pytest_ansible_network_integration/utils.py @@ -5,6 +5,7 @@ from typing import Any from typing import Dict from typing import List +from typing import Mapping def _print(message: str) -> None: @@ -67,6 +68,47 @@ def _inventory( return inventory +def _inventory_multi( + host: str, + device_to_ip: Mapping[str, str], + network_os: str, + username: str, + password: str, +) -> Dict[str, Any]: + """Build an ansible inventory for multiple devices. + + :param device_to_ip: Mapping of device name to its management IP + :param network_os: The network OS + :param username: Device username + :param password: Device password + :returns: The inventory for all devices under group 'all' + """ + hosts: Dict[str, Any] = {} + + for device_name, ip_address in device_to_ip.items(): + ports = calculate_ports(ip_address) + host_key = _sanitize_host_key(device_name) + hosts[host_key] = { + "ansible_become": False, + "ansible_host": host, + "ansible_user": username, + "ansible_password": password, + "ansible_port": ports["ssh_port"], + "ansible_httpapi_port": ports["http_port"], + "ansible_connection": "ansible.netcommon.network_cli", + "ansible_network_cli_ssh_type": "libssh", + "ansible_python_interpreter": "python", + "ansible_network_import_modules": True, + } + + return {"all": {"hosts": hosts, "vars": {"ansible_network_os": network_os}}} + + +def _sanitize_host_key(name: str) -> str: + """Return a safe inventory host key from an arbitrary device name.""" + return "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in name) + + def playbook(hosts: str, role: str) -> List[Dict[str, object]]: """Return the playbook.