diff --git a/docs/cli.md b/docs/cli.md
index 691776c35..b70786000 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -43,6 +43,21 @@ By default pyinfra only prints high level information (this host connected, this
 + `-vv`: as above plus print shell input to the remote host
 + `-vvv` as above plus print shell output from the remote host
 
+### Retry Options
+
+pyinfra supports automatic retry of failed operations via CLI options:
+
++ `--retry N`: Retry failed operations up to N times (default: 0)
++ `--retry-delay N`: Wait N seconds between retry attempts (default: 5)
+
+```sh
+# Retry failed operations up to 3 times with default 5 second delay
+pyinfra inventory.py deploy.py --retry 3
+
+# Retry with custom delay
+pyinfra inventory.py deploy.py --retry 2 --retry-delay 10
+```
+
 ## Inventory
diff --git a/docs/faq.rst b/docs/faq.rst
index 34842f01e..b8f15d0b0 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -49,3 +49,32 @@ Use the LINK ``files.file``, ``files.directory`` or ``files.link`` operations to
         group="pyinfra",
         mode=644,
     )
+
+How do I handle unreliable operations or network issues?
+--------------------------------------------------------
+
+Use the :doc:`retry behavior arguments <arguments>` to automatically retry failed operations. This is especially useful for network operations or services that may be temporarily unavailable:
+
+.. 
code:: python + + # Retry a network operation up to 3 times + server.shell( + name="Download file with retries", + commands=["wget https://example.com/file.zip"], + _retries=3, + _retry_delay=5, # wait 5 seconds between retries + ) + + # Use custom retry logic for specific error conditions + def should_retry_download(output_data): + # Retry only on temporary network errors, not permanent failures + stderr_text = " ".join(output_data["stderr_lines"]).lower() + temporary_errors = ["timeout", "connection refused", "temporary failure"] + return any(error in stderr_text for error in temporary_errors) + + server.shell( + name="Download with smart retry logic", + commands=["wget https://example.com/large-file.zip"], + _retries=3, + _retry_until=should_retry_download, + ) diff --git a/docs/using-operations.rst b/docs/using-operations.rst index fa6ecd4b4..66aab3e0e 100644 --- a/docs/using-operations.rst +++ b/docs/using-operations.rst @@ -49,6 +49,45 @@ Global arguments are covered in detail here: :doc:`arguments`. There is a set of _sudo_user="pyinfra", ) +Retry Functionality +------------------- + +Operations can be configured to retry automatically on failure using retry arguments: + +.. 
code:: python + + from pyinfra.operations import server + + # Retry a flaky command up to 3 times with default 5 second delay + server.shell( + name="Download file with retries", + commands=["curl -o /tmp/file.tar.gz https://example.com/file.tar.gz"], + _retries=3, + ) + + # Retry with custom delay between attempts + server.shell( + name="Check service status with retries", + commands=["systemctl is-active myservice"], + _retries=2, + _retry_delay=10, # 10 second delay between retries + ) + + # Use custom retry condition to control when to retry + def retry_on_network_error(output_data): + # Retry if stderr contains network-related errors + for line in output_data["stderr_lines"]: + if any(keyword in line.lower() for keyword in ["network", "timeout", "connection"]): + return True + return False + + server.shell( + name="Network operation with conditional retry", + commands=["wget https://example.com/large-file.zip"], + _retries=5, + _retry_until=retry_on_network_error, + ) + The ``host`` Object ------------------- diff --git a/pyinfra/api/arguments.py b/pyinfra/api/arguments.py index ca217fca9..a49401774 100644 --- a/pyinfra/api/arguments.py +++ b/pyinfra/api/arguments.py @@ -72,6 +72,11 @@ class ConnectorArguments(TypedDict, total=False): _get_pty: bool _stdin: Union[str, Iterable[str]] + # Retry arguments + _retries: int + _retry_delay: Union[int, float] + _retry_until: Optional[Callable[[dict], bool]] + def generate_env(config: "Config", value: dict) -> dict: env = config.ENV.copy() @@ -232,11 +237,28 @@ def all_global_arguments() -> List[tuple[str, Type]]: return list(get_type_hints(AllArguments).items()) +# Create a dictionary for retry arguments +retry_argument_meta: dict[str, ArgumentMeta] = { + "_retries": ArgumentMeta( + "Number of times to retry failed operations.", + default=lambda config: config.RETRY, + ), + "_retry_delay": ArgumentMeta( + "Delay in seconds between retry attempts.", + default=lambda config: config.RETRY_DELAY, + ), + "_retry_until": 
ArgumentMeta( + "Callable taking output data that returns True to continue retrying.", + default=lambda config: None, + ), +} + all_argument_meta: dict[str, ArgumentMeta] = { **auth_argument_meta, **shell_argument_meta, **meta_argument_meta, **execution_argument_meta, + **retry_argument_meta, # Add retry arguments } EXECUTION_KWARG_KEYS = list(ExecutionArguments.__annotations__.keys()) @@ -286,6 +308,45 @@ def all_global_arguments() -> List[tuple[str, Type]]: ), "Operation meta & callbacks": (meta_argument_meta, "", ""), "Execution strategy": (execution_argument_meta, "", ""), + "Retry behavior": ( + retry_argument_meta, + """ + Retry arguments allow you to automatically retry operations that fail. You can specify + how many times to retry, the delay between retries, and optionally a condition + function to determine when to stop retrying. + """, + """ + .. code:: python + + # Retry a command up to 3 times with the default 5 second delay + server.shell( + name="Run flaky command with retries", + commands=["flaky_command"], + _retries=3, + ) + # Retry with a custom delay + server.shell( + name="Run flaky command with custom delay", + commands=["flaky_command"], + _retries=2, + _retry_delay=10, # 10 second delay between retries + ) + # Retry with a custom condition + def retry_on_specific_error(output_data): + # Retry if stderr contains "temporary failure" + for line in output_data["stderr_lines"]: + if "temporary failure" in line.lower(): + return True + return False + + server.shell( + name="Run command with conditional retry", + commands=["flaky_command"], + _retries=5, + _retry_until=retry_on_specific_error, + ) + """, + ), } diff --git a/pyinfra/api/config.py b/pyinfra/api/config.py index 3a8a4f6a6..e2013afe2 100644 --- a/pyinfra/api/config.py +++ b/pyinfra/api/config.py @@ -53,6 +53,12 @@ class ConfigDefaults: IGNORE_ERRORS: bool = False # Shell to use to execute commands SHELL: str = "sh" + # Whether to display full diffs for files + DIFF: bool = False + # 
Number of times to retry failed operations + RETRY: int = 0 + # Delay in seconds between retry attempts + RETRY_DELAY: int = 5 config_defaults = {key: value for key, value in ConfigDefaults.__dict__.items() if key.isupper()} diff --git a/pyinfra/api/connect.py b/pyinfra/api/connect.py index bbe345dbc..ca119487e 100644 --- a/pyinfra/api/connect.py +++ b/pyinfra/api/connect.py @@ -46,5 +46,22 @@ def connect_all(state: "State"): def disconnect_all(state: "State"): - for host in state.activated_hosts: # only hosts we connected to please! - host.disconnect() # normally a noop + """ + Disconnect from all of the configured servers in parallel. Reads/writes state.inventory. + + Args: + state (``pyinfra.api.State`` obj): the state containing an inventory to connect to + """ + greenlet_to_host = { + state.pool.spawn(host.disconnect): host + for host in state.activated_hosts # only hosts we connected to please! + } + + with progress_spinner(greenlet_to_host.values()) as progress: + for greenlet in gevent.iwait(greenlet_to_host.keys()): + host = greenlet_to_host[greenlet] + progress(host) + + for greenlet, host in greenlet_to_host.items(): + # Raise any unexpected exception + greenlet.get() diff --git a/pyinfra/api/host.py b/pyinfra/api/host.py index ab8497c56..a42742306 100644 --- a/pyinfra/api/host.py +++ b/pyinfra/api/host.py @@ -328,6 +328,10 @@ def _get_temp_directory(self): return temp_directory + def get_temp_dir_config(self): + + return self.state.config.TEMP_DIR or self.state.config.DEFAULT_TEMP_DIR + def get_temp_filename( self, hash_key: Optional[str] = None, diff --git a/pyinfra/api/operation.py b/pyinfra/api/operation.py index 6eabae425..2911f63e5 100644 --- a/pyinfra/api/operation.py +++ b/pyinfra/api/operation.py @@ -47,6 +47,9 @@ class OperationMeta: _commands: Optional[list[Any]] = None _maybe_is_change: Optional[bool] = None _success: Optional[bool] = None + _retry_attempts: int = 0 + _max_retries: int = 0 + _retry_succeeded: Optional[bool] = None def 
__init__(self, hash, is_change: Optional[bool]): self._hash = hash @@ -59,9 +62,17 @@ def __repr__(self) -> str: """ if self._commands is not None: + retry_info = "" + if self._retry_attempts > 0: + retry_result = "succeeded" if self._retry_succeeded else "failed" + retry_info = ( + f", retries={self._retry_attempts}/{self._max_retries} ({retry_result})" + ) + return ( "OperationMeta(executed=True, " - f"success={self.did_succeed()}, hash={self._hash}, commands={len(self._commands)})" + f"success={self.did_succeed()}, hash={self._hash}, " + f"commands={len(self._commands)}{retry_info})" ) return ( "OperationMeta(executed=False, " @@ -74,12 +85,20 @@ def set_complete( success: bool, commands: list[Any], combined_output: "CommandOutput", + retry_attempts: int = 0, + max_retries: int = 0, ) -> None: if self.is_complete(): raise RuntimeError("Cannot complete an already complete operation") self._success = success self._commands = commands self._combined_output = combined_output + self._retry_attempts = retry_attempts + self._max_retries = max_retries + + # Determine if operation succeeded after retries + if retry_attempts > 0: + self._retry_succeeded = success def is_complete(self) -> bool: return self._success is not None @@ -150,6 +169,40 @@ def stdout(self) -> str: def stderr(self) -> str: return "\n".join(self.stderr_lines) + @property + def retry_attempts(self) -> int: + return self._retry_attempts + + @property + def max_retries(self) -> int: + return self._max_retries + + @property + def was_retried(self) -> bool: + """ + Returns whether this operation was retried at least once. + """ + return self._retry_attempts > 0 + + @property + def retry_succeeded(self) -> Optional[bool]: + """ + Returns whether this operation succeeded after retries. + Returns None if the operation was not retried. + """ + return self._retry_succeeded + + def get_retry_info(self) -> dict[str, Any]: + """ + Returns a dictionary with all retry-related information. 
+ """ + return { + "retry_attempts": self._retry_attempts, + "max_retries": self._max_retries, + "was_retried": self.was_retried, + "retry_succeeded": self._retry_succeeded, + } + def add_op(state: State, op_func, *args, **kwargs): """ diff --git a/pyinfra/api/operations.py b/pyinfra/api/operations.py index 706e1b24f..415bcb05b 100644 --- a/pyinfra/api/operations.py +++ b/pyinfra/api/operations.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time import traceback from itertools import product from socket import error as socket_error, timeout as timeout_error @@ -66,6 +67,11 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]: continue_on_error = global_arguments["_continue_on_error"] timeout = global_arguments.get("_timeout", 0) + # Extract retry arguments + retries = global_arguments.get("_retries", 0) + retry_delay = global_arguments.get("_retry_delay", 5) + retry_until = global_arguments.get("_retry_until", None) + executor_kwarg_keys = CONNECTOR_ARGUMENT_KEYS # See: https://github.com/python/mypy/issues/10371 base_connector_arguments: ConnectorArguments = cast( @@ -73,67 +79,114 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]: {key: global_arguments[key] for key in executor_kwarg_keys if key in global_arguments}, # type: ignore[literal-required] # noqa ) + retry_attempt = 0 did_error = False executed_commands = 0 - commands = [] + commands: list[PyinfraCommand] = [] all_output_lines: list[OutputLine] = [] - for command in op_data.command_generator(): - commands.append(command) - - status = False - - connector_arguments = base_connector_arguments.copy() - connector_arguments.update(command.connector_arguments) - - if not isinstance(command, PyinfraCommand): - raise TypeError("{0} is an invalid pyinfra command!".format(command)) - - if isinstance(command, FunctionCommand): - try: - status = command.execute(state, host, connector_arguments) - except Exception as e: - # Custom functions 
could do anything, so expect anything! - logger.warning(traceback.format_exc()) - host.log_styled( - f"Unexpected error in Python callback: {format_exception(e)}", - fg="red", - log_func=logger.warning, - ) - - elif isinstance(command, StringCommand): - output_lines = CommandOutput([]) - try: - status, output_lines = command.execute( - state, - host, - connector_arguments, - ) - except (timeout_error, socket_error, SSHException) as e: - log_host_command_error(host, e, timeout=timeout) - all_output_lines.extend(output_lines) - # If we failed and have not already printed the stderr, print it - if status is False and not state.print_output: - print_host_combined_output(host, output_lines) - - else: - try: - status = command.execute(state, host, connector_arguments) - except (timeout_error, socket_error, SSHException, IOError) as e: - log_host_command_error(host, e, timeout=timeout) - - # Break the loop to trigger a failure - if status is False: - did_error = True - if continue_on_error is True: - continue - break + # Retry loop + while retry_attempt <= retries: + did_error = False + executed_commands = 0 + commands = [] + all_output_lines = [] + + for command in op_data.command_generator(): + commands.append(command) + status = False + connector_arguments = base_connector_arguments.copy() + connector_arguments.update(command.connector_arguments) + + if not isinstance(command, PyinfraCommand): + raise TypeError("{0} is an invalid pyinfra command!".format(command)) + + if isinstance(command, FunctionCommand): + try: + status = command.execute(state, host, connector_arguments) + except Exception as e: + # Custom functions could do anything, so expect anything! 
+ logger.warning(traceback.format_exc()) + host.log_styled( + f"Unexpected error in Python callback: {format_exception(e)}", + fg="red", + log_func=logger.warning, + ) + + elif isinstance(command, StringCommand): + output_lines = CommandOutput([]) + try: + status, output_lines = command.execute( + state, + host, + connector_arguments, + ) + except (timeout_error, socket_error, SSHException) as e: + log_host_command_error(host, e, timeout=timeout) + all_output_lines.extend(output_lines) + # If we failed and have not already printed the stderr, print it + if status is False and not state.print_output: + print_host_combined_output(host, output_lines) + + else: + try: + status = command.execute(state, host, connector_arguments) + except (timeout_error, socket_error, SSHException, IOError) as e: + log_host_command_error(host, e, timeout=timeout) + + # Break the loop to trigger a failure + if status is False: + did_error = True + if continue_on_error is True: + continue + break + + executed_commands += 1 + + # Check if we should retry + should_retry = False + if retry_attempt < retries: + # Retry on error + if did_error: + should_retry = True + # Retry on condition if no error + elif retry_until and not did_error: + try: + output_data = { + "stdout_lines": [ + line.line for line in all_output_lines if line.buffer_name == "stdout" + ], + "stderr_lines": [ + line.line for line in all_output_lines if line.buffer_name == "stderr" + ], + "commands": [str(command) for command in commands], + "executed_commands": executed_commands, + "host": host.name, + "operation": ", ".join(state.get_op_meta(op_hash).names) or "Operation", + } + should_retry = retry_until(output_data) + except Exception as e: + host.log_styled( + f"Error in retry_until function: {format_exception(e)}", + fg="red", + log_func=logger.warning, + ) + + if should_retry: + retry_attempt += 1 + state.trigger_callbacks("operation_host_retry", host, op_hash, retry_attempt, retries) + op_name = ", 
".join(state.get_op_meta(op_hash).names) or "Operation" + host.log_styled( + f"Retrying {op_name} (attempt {retry_attempt}/{retries}) after {retry_delay}s...", + fg="yellow", + log_func=logger.info, + ) + time.sleep(retry_delay) + continue - executed_commands += 1 + break # Handle results - # - op_success = return_status = not did_error host_results = state.get_results_for_host(host) @@ -142,10 +195,13 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]: host_results.success_ops += 1 _status_log = "Success" if executed_commands > 0 else "No changes" + if retry_attempt > 0: + _status_log = f"{_status_log} on retry {retry_attempt}" + _click_log_status = click.style(_status_log, "green") logger.info("{0}{1}".format(host.print_prefix, _click_log_status)) - state.trigger_callbacks("operation_host_success", host, op_hash) + state.trigger_callbacks("operation_host_success", host, op_hash, retry_attempt) else: if ignore_errors: host_results.ignored_error_ops += 1 @@ -156,6 +212,11 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]: host_results.partial_ops += 1 _command_description = f"executed {executed_commands} commands" + if retry_attempt > 0: + _command_description = ( + f"{_command_description} (failed after {retry_attempt}/{retries} retries)" + ) + log_error_or_warning(host, ignore_errors, _command_description, continue_on_error) # Ignored, op "completes" w/ ignored error @@ -164,12 +225,14 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]: return_status = True # Unignored error -> False - state.trigger_callbacks("operation_host_error", host, op_hash) + state.trigger_callbacks("operation_host_error", host, op_hash, retry_attempt, retries) op_data.operation_meta.set_complete( op_success, commands, CommandOutput(all_output_lines), + retry_attempts=retry_attempt, + max_retries=retries, ) return return_status diff --git a/pyinfra/api/state.py b/pyinfra/api/state.py index 
ab4af5bdb..7937c09fb 100644 --- a/pyinfra/api/state.py +++ b/pyinfra/api/state.py @@ -70,11 +70,19 @@ def operation_host_start(state: "State", host: "Host", op_hash): pass @staticmethod - def operation_host_success(state: "State", host: "Host", op_hash): + def operation_host_success(state: "State", host: "Host", op_hash, retry_count: int = 0): pass @staticmethod - def operation_host_error(state: "State", host: "Host", op_hash): + def operation_host_error( + state: "State", host: "Host", op_hash, retry_count: int = 0, max_retries: int = 0 + ): + pass + + @staticmethod + def operation_host_retry( + state: "State", host: "Host", op_hash, retry_num: int, max_retries: int + ): pass @staticmethod diff --git a/pyinfra/connectors/scp/__init__.py b/pyinfra/connectors/scp/__init__.py new file mode 100644 index 000000000..4652f5cca --- /dev/null +++ b/pyinfra/connectors/scp/__init__.py @@ -0,0 +1 @@ +from .client import SCPClient # noqa: F401 diff --git a/pyinfra/connectors/scp/client.py b/pyinfra/connectors/scp/client.py new file mode 100644 index 000000000..d57bfcd41 --- /dev/null +++ b/pyinfra/connectors/scp/client.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import ntpath +import os +from pathlib import PurePath +from shlex import quote +from socket import timeout as SocketTimeoutError +from typing import IO, AnyStr + +from paramiko import Channel +from paramiko.transport import Transport + +SCP_COMMAND = b"scp" + + +# Unicode conversion functions; assume UTF-8 +def asbytes(s: bytes | str | PurePath) -> bytes: + """Turns unicode into bytes, if needed. + + Assumes UTF-8. + """ + if isinstance(s, bytes): + return s + elif isinstance(s, PurePath): + return bytes(s) + else: + return s.encode("utf-8") + + +def asunicode(s: bytes | str) -> str: + """Turns bytes into unicode, if needed. + + Uses UTF-8. 
+ """ + if isinstance(s, bytes): + return s.decode("utf-8", "replace") + else: + return s + + +class SCPClient: + """ + An scp1 implementation, compatible with openssh scp. + Raises SCPException for all transport related errors. Local filesystem + and OS errors pass through. + + Main public methods are .putfo and .getfo + """ + + def __init__( + self, + transport: Transport, + buff_size: int = 16384, + socket_timeout: float = 10.0, + ): + self.transport = transport + self.buff_size = buff_size + self.socket_timeout = socket_timeout + self._channel: Channel | None = None + self.scp_command = SCP_COMMAND + + @property + def channel(self) -> Channel: + """Return an open Channel, (re)opening if needed.""" + if self._channel is None or self._channel.closed: + self._channel = self.transport.open_session() + return self._channel + + def __enter__(self): + _ = self.channel # triggers opening if not already open + return self + + def __exit__(self, type, value, traceback): + self.close() + + def putfo( + self, + fl: IO[AnyStr], + remote_path: str | bytes, + mode: str | bytes = "0644", + size: int | None = None, + ) -> None: + if size is None: + pos = fl.tell() + fl.seek(0, os.SEEK_END) # Seek to end + size = fl.tell() - pos + fl.seek(pos, os.SEEK_SET) # Seek back + + self.channel.settimeout(self.socket_timeout) + self.channel.exec_command( + self.scp_command + b" -t " + asbytes(quote(asunicode(remote_path))) + ) + self._recv_confirm() + self._send_file(fl, remote_path, mode, size=size) + self.close() + + def getfo(self, remote_path: str, fl: IO): + remote_path_sanitized = quote(remote_path) + if os.name == "nt": + remote_file_name = ntpath.basename(remote_path_sanitized) + else: + remote_file_name = os.path.basename(remote_path_sanitized) + self.channel.settimeout(self.socket_timeout) + self.channel.exec_command(self.scp_command + b" -f " + asbytes(remote_path_sanitized)) + self._recv_all(fl, remote_file_name) + self.close() + return fl + + def close(self): + """close scp 
channel""" + if self._channel is not None: + self._channel.close() + self._channel = None + + def _send_file(self, fl, name, mode, size): + basename = asbytes(os.path.basename(name)) + # The protocol can't handle \n in the filename. + # Quote them as the control sequence \^J for now, + # which is how openssh handles it. + self.channel.sendall( + ("C%s %d " % (mode, size)).encode("ascii") + basename.replace(b"\n", b"\\^J") + b"\n" + ) + self._recv_confirm() + file_pos = 0 + buff_size = self.buff_size + chan = self.channel + while file_pos < size: + chan.sendall(fl.read(buff_size)) + file_pos = fl.tell() + chan.sendall(b"\x00") + self._recv_confirm() + + def _recv_confirm(self): + # read scp response + msg = b"" + try: + msg = self.channel.recv(512) + except SocketTimeoutError: + raise SCPException("Timeout waiting for scp response") + # slice off the first byte, so this compare will work in py2 and py3 + if msg and msg[0:1] == b"\x00": + return + elif msg and msg[0:1] == b"\x01": + raise SCPException(asunicode(msg[1:])) + elif self.channel.recv_stderr_ready(): + msg = self.channel.recv_stderr(512) + raise SCPException(asunicode(msg)) + elif not msg: + raise SCPException("No response from server") + else: + raise SCPException("Invalid response from server", msg) + + def _recv_all(self, fh: IO, remote_file_name: str) -> None: + # loop over scp commands, and receive as necessary + commands = (b"C",) + while not self.channel.closed: + # wait for command as long as we're open + self.channel.sendall(b"\x00") + msg = self.channel.recv(1024) + if not msg: # chan closed while receiving + break + assert msg[-1:] == b"\n" + msg = msg[:-1] + code = msg[0:1] + if code not in commands: + raise SCPException(asunicode(msg[1:])) + self._recv_file(msg[1:], fh, remote_file_name) + + def _recv_file(self, cmd: bytes, fh: IO, remote_file_name: str) -> None: + chan = self.channel + parts = cmd.strip().split(b" ", 2) + + try: + size = int(parts[1]) + except (ValueError, IndexError): + 
chan.send(b"\x01") + chan.close() + raise SCPException("Bad file format") + + buff_size = self.buff_size + pos = 0 + chan.send(b"\x00") + try: + while pos < size: + # we have to make sure we don't read the final byte + if size - pos <= buff_size: + buff_size = size - pos + data = chan.recv(buff_size) + if not data: + raise SCPException("Underlying channel was closed") + fh.write(data) + pos = fh.tell() + msg = chan.recv(512) + if msg and msg[0:1] != b"\x00": + raise SCPException(asunicode(msg[1:])) + except SocketTimeoutError: + chan.close() + raise SCPException("Error receiving, socket.timeout") + + +class SCPException(Exception): + """SCP exception class""" + + pass diff --git a/pyinfra/connectors/ssh.py b/pyinfra/connectors/ssh.py index ff3c1f19d..43d015191 100644 --- a/pyinfra/connectors/ssh.py +++ b/pyinfra/connectors/ssh.py @@ -5,7 +5,7 @@ from shutil import which from socket import error as socket_error, gaierror from time import sleep -from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple +from typing import IO, TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple import click from paramiko import AuthenticationException, BadHostKeyException, SFTPClient, SSHException @@ -17,6 +17,7 @@ from pyinfra.api.util import get_file_io, memoize from .base import BaseConnector, DataMeta +from .scp import SCPClient from .ssh_util import get_private_key, raise_connect_error from .sshuserclient import SSHClient from .util import ( @@ -53,6 +54,7 @@ class ConnectorData(TypedDict): ssh_connect_retries: int ssh_connect_retry_min_delay: float ssh_connect_retry_max_delay: float + ssh_file_transfer_protocol: str connector_data_meta: dict[str, DataMeta] = { @@ -92,9 +94,27 @@ class ConnectorData(TypedDict): "Upper bound for random delay between retries", 0.5, ), + "ssh_file_transfer_protocol": DataMeta( + "Protocol to use for file transfers. 
Can be ``sftp`` or ``scp``.", + "sftp", + ), } +class FileTransferClient(Protocol): + def getfo(self, remote_filename: str, fl: IO) -> Any | None: + """ + Get a file from the remote host, writing to the provided file-like object. + """ + ... + + def putfo(self, fl: IO, remote_filename: str) -> Any | None: + """ + Put a file to the remote host, reading from the provided file-like object. + """ + ... + + class SSHConnector(BaseConnector): """ Connect to hosts over SSH. This is the default connector and all targets default @@ -268,7 +288,7 @@ def _connect(self) -> None: @override def disconnect(self) -> None: - self.get_sftp_connection.cache.clear() + self.get_file_transfer_connection.cache.clear() @override def run_shell_command( @@ -353,13 +373,25 @@ def execute_command() -> Tuple[int, CommandOutput]: return status, combined_output @memoize - def get_sftp_connection(self): + def get_file_transfer_connection(self) -> FileTransferClient | None: assert self.client is not None transport = self.client.get_transport() assert transport is not None, "No transport" try: - return SFTPClient.from_transport(transport) + if self.data["ssh_file_transfer_protocol"] == "sftp": + logger.debug("Using SFTP for file transfer") + return SFTPClient.from_transport(transport) + elif self.data["ssh_file_transfer_protocol"] == "scp": + logger.debug("Using SCP for file transfer") + return SCPClient(transport) + else: + raise ConnectError( + "Unsupported file transfer protocol: {0}".format( + self.data["ssh_file_transfer_protocol"], + ), + ) except SSHException as e: + raise ConnectError( ( "Unable to establish SFTP connection. 
Check that the SFTP subsystem " @@ -367,9 +399,9 @@ def get_sftp_connection(self): ).format(self.host), ) from e - def _get_file(self, remote_filename: str, filename_or_io): + def _get_file(self, remote_filename: str, filename_or_io: str | IO): with get_file_io(filename_or_io, "wb") as file_io: - sftp = self.get_sftp_connection() + sftp = self.get_file_transfer_connection() sftp.getfo(remote_filename, file_io) @override @@ -448,7 +480,7 @@ def _put_file(self, filename_or_io, remote_location): while attempts < 3: try: with get_file_io(filename_or_io) as file_io: - sftp = self.get_sftp_connection() + sftp = self.get_file_transfer_connection() sftp.putfo(file_io, remote_location) return except OSError as e: diff --git a/pyinfra/connectors/util.py b/pyinfra/connectors/util.py index 3708e1499..d04b00db2 100644 --- a/pyinfra/connectors/util.py +++ b/pyinfra/connectors/util.py @@ -22,17 +22,17 @@ SUDO_ASKPASS_ENV_VAR = "PYINFRA_SUDO_PASSWORD" + + SUDO_ASKPASS_COMMAND = r""" -temp=$(mktemp "${{TMPDIR:=/tmp}}/pyinfra-sudo-askpass-XXXXXXXXXXXX") +temp=$(mktemp "${{TMPDIR:={0}}}/pyinfra-sudo-askpass-XXXXXXXXXXXX") cat >"$temp"<<'__EOF__' #!/bin/sh printf '%s\n' "${0}" __EOF__ chmod 755 "$temp" echo "$temp" -""".format( - SUDO_ASKPASS_ENV_VAR, -) +""" def run_local_process( @@ -264,7 +264,9 @@ def extract_control_arguments(arguments: "ConnectorArguments") -> "ConnectorArgu def _ensure_sudo_askpass_set_for_host(host: "Host"): if host.connector_data.get("sudo_askpass_path"): return - _, output = host.run_shell_command(SUDO_ASKPASS_COMMAND) + _, output = host.run_shell_command( + SUDO_ASKPASS_COMMAND.format(host.get_temp_dir_config(), SUDO_ASKPASS_ENV_VAR) + ) host.connector_data["sudo_askpass_path"] = shlex.quote(output.stdout_lines[0]) @@ -314,6 +316,10 @@ def make_unix_command( # Doas config _doas=False, _doas_user=None, + # Retry config (ignored in command generation but passed through) + _retries=0, + _retry_delay=0, + _retry_until=None, ) -> StringCommand: """ Builds a 
shell command with various kwargs.
diff --git a/pyinfra/facts/dnf.py b/pyinfra/facts/dnf.py
index 5f80727b0..f20e5dbda 100644
--- a/pyinfra/facts/dnf.py
+++ b/pyinfra/facts/dnf.py
@@ -16,11 +16,15 @@ class DnfRepositories(FactBase):
 
         [
             {
-                "name": "CentOS-$releasever - Apps",
-                "baseurl": "http://mirror.centos.org/$contentdir/$releasever/Apps/$basearch/os/",
-                "gpgcheck": "1",
+                "repoid": "baseos",
+                "name": "AlmaLinux $releasever - BaseOS",
+                "mirrorlist": "https://mirrors.almalinux.org/mirrorlist/$releasever/baseos",
                 "enabled": "1",
-                "gpgkey": "file:///etc/pki/rpm-gpg/RPM-GPG-KEY-centosofficial",
+                "gpgcheck": "1",
+                "countme": "1",
+                "gpgkey": "file:///etc/pki/rpm-gpg/RPM-GPG-KEY-AlmaLinux-9",
+                "metadata_expire": "86400",
+                "enabled_metadata": "1"
             },
         ]
     """
diff --git a/pyinfra/facts/docker.py b/pyinfra/facts/docker.py
index 14e925036..5a74f24a5 100644
--- a/pyinfra/facts/docker.py
+++ b/pyinfra/facts/docker.py
@@ -76,6 +76,28 @@ def command(self) -> str:
         return "docker network inspect `docker network ls -q`"
 
 
+class DockerVolumes(DockerFactBase):
+    """
+    Returns ``docker inspect`` output for all Docker volumes.
+    """
+
+    @override
+    def command(self) -> str:
+        return "docker volume inspect `docker volume ls -q`"
+
+
+class DockerPlugins(DockerFactBase):
+    """
+    Returns ``docker plugin inspect`` output for all Docker plugins.
+    """
+
+    @override
+    def command(self) -> str:
+        return """
+        ids=$(docker plugin ls -q) && [ -n "$ids" ] && docker plugin inspect $ids || echo "[]"
+        """.strip()
+
+
 # Single Docker objects
 #
 
@@ -113,19 +135,17 @@ class DockerNetwork(DockerSingleMixin):
     docker_type = "network"
 
 
-class DockerVolumes(DockerFactBase):
+class DockerVolume(DockerSingleMixin):
     """
-    Returns ``docker inspect`` output for all Docker volumes.
+    Returns ``docker inspect`` output for a single Docker volume.
"""
 
-    @override
-    def command(self) -> str:
-        return "docker volume inspect `docker volume ls -q`"
+    docker_type = "volume"
 
 
-class DockerVolume(DockerSingleMixin):
+class DockerPlugin(DockerSingleMixin):
     """
-    Returns ``docker inspect`` output for a single Docker container.
+    Returns ``docker plugin inspect`` output for a single Docker plugin.
     """
 
-    docker_type = "volume"
+    docker_type = "plugin"
diff --git a/pyinfra/facts/files.py b/pyinfra/facts/files.py
index 3a32b1f12..f58648c7a 100644
--- a/pyinfra/facts/files.py
+++ b/pyinfra/facts/files.py
@@ -643,10 +643,13 @@ def command(self, path, marker=None, begin=None, end=None):
             QuoteString(f"{EXISTS}{path}"),
             QuoteString(f"{MISSING}{path}"),
         )
-        # m_f_s_c inserts blanks in unfortunate places, e.g. after first slash
-        cmd = make_formatted_string_command(
-            f"awk \\'/{end}/{{{{f=0}}}} f; /{start}/{{{{f=1}}}}\\' {{0}} || {backstop}",
+
+        cmd = StringCommand(
+            f"awk '/{end}/{{ f=0}} f; /{start}/{{ f=1}} ' ",
             QuoteString(path),
+            " || ",
+            backstop,
+            _separator="",
         )
         return cmd
@@ -657,3 +660,18 @@ def process(self, output):
         if output and (output[0] == f"{MISSING}{self.path}"):
             return None
         return output
+
+
+class FileContents(FactBase):
+    """
+    Returns the contents of a file as a list of lines. Returns ``None`` if the
+    file does not exist. 
+ """ + + @override + def command(self, path): + return make_formatted_string_command("cat {0}", QuoteString(path)) + + @override + def process(self, output): + return output diff --git a/pyinfra/facts/server.py b/pyinfra/facts/server.py index 639ad5851..7ff7c86cd 100644 --- a/pyinfra/facts/server.py +++ b/pyinfra/facts/server.py @@ -661,7 +661,9 @@ def process(self, output) -> LinuxDistributionDict: for filename, content in parts.items(): with open( - os.path.join(temp_etc_dir, os.path.basename(filename)), "w", encoding="utf-8" + os.path.join(temp_etc_dir, os.path.basename(filename)), + "w", + encoding="utf-8", ) as fp: fp.write(content) @@ -901,3 +903,52 @@ def process(self, output): ) return limits + + +class RebootRequired(FactBase[bool]): + """ + Returns a boolean indicating whether the system requires a reboot. + + On Linux systems: + - Checks /var/run/reboot-required and /var/run/reboot-required.pkgs + - On Alpine Linux, compares installed kernel with running kernel + + On FreeBSD systems: + - Compares running kernel version with installed kernel version + """ + + @override + def command(self) -> str: + return """ +# Get OS type +OS_TYPE=$(uname -s) +if [ "$OS_TYPE" = "Linux" ]; then + # Check if it's Alpine Linux + if [ -f /etc/alpine-release ]; then + RUNNING_KERNEL=$(uname -r) + INSTALLED_KERNEL=$(ls -1 /lib/modules | sort -V | tail -n1) + if [ "$RUNNING_KERNEL" != "$INSTALLED_KERNEL" ]; then + echo "reboot_required" + exit 0 + fi + else + # Check standard Linux reboot required files + if [ -f /var/run/reboot-required ] || [ -f /var/run/reboot-required.pkgs ]; then + echo "reboot_required" + exit 0 + fi + fi +elif [ "$OS_TYPE" = "FreeBSD" ]; then + RUNNING_VERSION=$(freebsd-version -r) + INSTALLED_VERSION=$(freebsd-version -k) + if [ "$RUNNING_VERSION" != "$INSTALLED_VERSION" ]; then + echo "reboot_required" + exit 0 + fi +fi +echo "no_reboot_required" +""" + + @override + def process(self, output) -> bool: + return list(output)[0].strip() == 
"reboot_required" diff --git a/pyinfra/facts/util/packaging.py b/pyinfra/facts/util/packaging.py index 22442823e..017d23cf2 100644 --- a/pyinfra/facts/util/packaging.py +++ b/pyinfra/facts/util/packaging.py @@ -32,6 +32,7 @@ def _parse_yum_or_zypper_repositories(output): repos.append(current_repo) current_repo = {} + current_repo["repoid"] = line[1:-1] current_repo["name"] = line[1:-1] if current_repo and "=" in line: diff --git a/pyinfra/facts/yum.py b/pyinfra/facts/yum.py index 437d41d65..caf792365 100644 --- a/pyinfra/facts/yum.py +++ b/pyinfra/facts/yum.py @@ -16,11 +16,15 @@ class YumRepositories(FactBase): [ { - "name": "CentOS-$releasever - Apps", - "baseurl": "http://mirror.centos.org/$contentdir/$releasever/Apps/$basearch/os/", - "gpgcheck": "1", + "repoid": "baseos", + "name": "AlmaLinux $releasever - BaseOS", + "mirrorlist": "https://mirrors.almalinux.org/mirrorlist/$releasever/baseos", "enabled": "1", - "gpgkey": "file:///etc/pki/rpm-gpg/RPM-GPG-KEY-centosofficial", + "gpgcheck": "1", + "countme": "1", + "gpgkey": "file:///etc/pki/rpm-gpg/RPM-GPG-KEY-AlmaLinux-9", + "metadata_expire": "86400", + "enabled_metadata": "1" }, ] """ diff --git a/pyinfra/facts/zypper.py b/pyinfra/facts/zypper.py index aef4da801..55963f97e 100644 --- a/pyinfra/facts/zypper.py +++ b/pyinfra/facts/zypper.py @@ -16,11 +16,11 @@ class ZypperRepositories(FactBase): [ { + "repoid": "repo-oss", "name": "Main Repository", "enabled": "1", - "autorefresh": "0", - "baseurl": "http://download.opensuse.org/distribution/leap/$releasever/repo/oss/", - "type": "rpm-md", + "autorefresh": "1", + "baseurl": "http://download.opensuse.org/distribution/leap/$releasever/repo/oss/" }, ] """ diff --git a/pyinfra/operations/docker.py b/pyinfra/operations/docker.py index 6320f814c..d29d53739 100644 --- a/pyinfra/operations/docker.py +++ b/pyinfra/operations/docker.py @@ -4,25 +4,27 @@ as inventory directly. 
""" +from __future__ import annotations + from pyinfra import host from pyinfra.api import operation -from pyinfra.facts.docker import DockerContainer, DockerNetwork, DockerVolume +from pyinfra.facts.docker import DockerContainer, DockerNetwork, DockerPlugin, DockerVolume from .util.docker import ContainerSpec, handle_docker @operation() def container( - container, - image="", - ports=None, - networks=None, - volumes=None, - env_vars=None, - pull_always=False, - present=True, - force=False, - start=True, + container: str, + image: str = "", + ports: list[str] | None = None, + networks: list[str] | None = None, + volumes: list[str] | None = None, + env_vars: list[str] | None = None, + pull_always: bool = False, + present: bool = True, + force: bool = False, + start: bool = True, ): """ Manage Docker containers @@ -168,7 +170,7 @@ def image(image, present=True): @operation() -def volume(volume, driver="", labels=None, present=True): +def volume(volume: str, driver: str = "", labels: list[str] | None = None, present: bool = True): """ Manage Docker volumes @@ -220,20 +222,20 @@ def volume(volume, driver="", labels=None, present=True): @operation() def network( - network, - driver="", - gateway="", - ip_range="", - ipam_driver="", - subnet="", - scope="", - aux_addresses=None, - opts=None, - ipam_opts=None, - labels=None, - ingress=False, - attachable=False, - present=True, + network: str, + driver: str = "", + gateway: str = "", + ip_range: str = "", + ipam_driver: str = "", + subnet: str = "", + scope: str = "", + aux_addresses: dict[str, str] | None = None, + opts: list[str] | None = None, + ipam_opts: list[str] | None = None, + labels: list[str] | None = None, + ingress: bool = False, + attachable: bool = False, + present: bool = True, ): """ Manage docker networks @@ -245,6 +247,7 @@ def network( + ipam_driver: IP Address Management Driver + subnet: Subnet in CIDR format that represents a network segment + scope: Control the network's scope + + aux_addresses: 
named aux addresses for the network + opts: Set driver specific options + ipam_opts: Set IPAM driver specific options + labels: Label list to attach in the network @@ -303,9 +306,9 @@ def network( @operation(is_idempotent=False) def prune( - all=False, - volumes=False, - filter="", + all: bool = False, + volumes: bool = False, + filter: str = "", ): """ Execute a docker system prune. @@ -344,3 +347,101 @@ def prune( volumes=volumes, filter=filter, ) + + +@operation() +def plugin( + plugin: str, + alias: str | None = None, + present: bool = True, + enabled: bool = True, + plugin_options: dict[str, str] | None = None, +): + """ + Manage Docker plugins + + + plugin: Plugin name + + alias: Alias for the plugin (optional) + + present: Whether the plugin should be installed + + enabled: Whether the plugin should be enabled + + plugin_options: Options to pass to the plugin + + **Examples:** + + .. code:: python + + # Install and enable a Docker plugin + docker.plugin( + name="Install and enable a Docker plugin", + plugin="username/my-awesome-plugin:latest", + alias="my-plugin", + present=True, + enabled=True, + plugin_options={"option1": "value1", "option2": "value2"}, + ) + """ + plugin_name = alias if alias else plugin + existent_plugin = host.get_fact(DockerPlugin, object_id=plugin_name) + if existent_plugin: + existent_plugin = existent_plugin[0] + + if present: + if existent_plugin: + plugin_options_different = ( + plugin_options and existent_plugin["Settings"]["Env"] != plugin_options + ) + if plugin_options_different: + # Update options on existing plugin + if existent_plugin["Enabled"]: + yield handle_docker( + resource="plugin", + command="disable", + plugin=plugin_name, + ) + yield handle_docker( + resource="plugin", + command="set", + plugin=plugin_name, + enabled=enabled, + existent_options=existent_plugin["Settings"]["Env"], + required_options=plugin_options, + ) + if enabled: + yield handle_docker( + resource="plugin", + command="enable", + 
plugin=plugin_name, + ) + else: + # Options are the same, check if enabled state is different + if existent_plugin["Enabled"] == enabled: + host.noop( + f"Plugin '{plugin_name}' is already installed with the same options " + f"and {'enabled' if enabled else 'disabled'}." + ) + return + else: + command = "enable" if enabled else "disable" + yield handle_docker( + resource="plugin", + command=command, + plugin=plugin_name, + ) + else: + yield handle_docker( + resource="plugin", + command="install", + plugin=plugin, + alias=alias, + enabled=enabled, + plugin_options=plugin_options, + ) + else: + if not existent_plugin: + host.noop(f"Plugin '{plugin_name}' is not installed.") + return + yield handle_docker( + resource="plugin", + command="remove", + plugin=plugin_name, + ) diff --git a/pyinfra/operations/files.py b/pyinfra/operations/files.py index b17c55626..b9c1c017f 100644 --- a/pyinfra/operations/files.py +++ b/pyinfra/operations/files.py @@ -14,6 +14,7 @@ from pathlib import Path from typing import IO, Any, Union +import click from jinja2 import TemplateRuntimeError, TemplateSyntaxError, UndefinedError from pyinfra import host, logger, state @@ -46,6 +47,7 @@ Block, Directory, File, + FileContents, FindFiles, FindInFile, Flags, @@ -62,6 +64,7 @@ MetadataTimeField, adjust_regex, ensure_mode_int, + generate_color_diff, get_timestamp, sed_delete, sed_replace, @@ -1030,6 +1033,23 @@ def put( # File exists, check sum and check user/group/mode/atime/mtime if supplied else: if not _file_equal(local_sum_path, dest): + if state.config.DIFF: + # Generate diff when contents change + current_contents = host.get_fact(FileContents, path=dest) + if current_contents: + current_lines = [line + "\n" for line in current_contents] + else: + current_lines = [] + + logger.info(f"\n Will modify {click.style(dest, bold=True)}") + + with get_file_io(src, "r") as f: + desired_lines = f.readlines() + + for line in generate_color_diff(current_lines, desired_lines): + logger.info(f" {line}") 
+ logger.info("") + yield FileUploadCommand( local_file, dest, @@ -1718,7 +1738,7 @@ def block( path="/etc/hosts", content="10.0.0.1 mars-one", before=True, - regex=".*localhost", + line=".*localhost", ) # have two entries in /etc/host @@ -1727,7 +1747,7 @@ def block( path="/etc/hosts", content="10.0.0.1 mars-one\\n10.0.0.2 mars-two", before=True, - regex=".*localhost", + line=".*localhost", ) # remove marked entry from /etc/hosts @@ -1742,7 +1762,7 @@ def block( name="add out of date warning to web page", path="/var/www/html/something.html", content= "

Warning: this page is out of date.

", - regex=".*.*", + line=".*.*", after=True marker="", ) @@ -1764,7 +1784,8 @@ def block( # standard awk doesn't have an "in-place edit" option so we write to a tempfile and # if edits were successful move to dest i.e. we do: ... do some work ... q_path = QuoteString(path) - out_prep = StringCommand('OUT="$(TMPDIR=/tmp mktemp -t pyinfra.XXXXXX)" && ') + tmp_dir = host.get_temp_dir_config() + out_prep = StringCommand(f'OUT="$(TMPDIR:={tmp_dir} mktemp -t pyinfra.XXXXXX)" && ') if backup: out_prep = StringCommand( "cp", @@ -1789,6 +1810,7 @@ def block( ) current = host.get_fact(Block, path=path, marker=marker, begin=begin, end=end) + # None means file didn't exist, empty list means marker was not found cmd = None if present: if not content: diff --git a/pyinfra/operations/git.py b/pyinfra/operations/git.py index 6e53506f4..905d3ae12 100644 --- a/pyinfra/operations/git.py +++ b/pyinfra/operations/git.py @@ -148,7 +148,7 @@ def repo( if branch and host.get_fact(GitBranch, repo=dest) != branch: git_commands.append("fetch") # fetch to ensure we have the branch locally git_commands.append("checkout {0}".format(branch)) - if branch and branch in host.get_fact(GitTag, repo=dest): + if branch and branch in (host.get_fact(GitTag, repo=dest) or []): git_commands.append("checkout {0}".format(branch)) is_tag = True if pull and not is_tag: diff --git a/pyinfra/operations/pip.py b/pyinfra/operations/pip.py index 1704daf74..fe2cb8a9c 100644 --- a/pyinfra/operations/pip.py +++ b/pyinfra/operations/pip.py @@ -11,7 +11,7 @@ from pyinfra.facts.pip import PipPackages from . import files -from .util.packaging import ensure_packages +from .util.packaging import PkgInfo, ensure_packages @operation() @@ -186,21 +186,20 @@ def packages( # Handle passed in packages if packages: + if isinstance(packages, str): + packages = [packages] + # PEP-0426 states that Python packages should be compared using lowercase, so lowercase the + # current packages. 
PkgInfo.from_pep508 takes care of the package name current_packages = host.get_fact(PipPackages, pip=pip) - - # PEP-0426 states that Python packages should be compared using lowercase, so lowercase both - # the input packages and the fact packages before comparison. - packages = [pkg.lower() for pkg in packages] current_packages = {pkg.lower(): versions for pkg, versions in current_packages.items()} yield from ensure_packages( host, - packages, + list(filter(None, (PkgInfo.from_pep508(package) for package in packages))), current_packages, present, install_command=install_command, uninstall_command=uninstall_command, upgrade_command=upgrade_command, - version_join="==", latest=latest, ) diff --git a/pyinfra/operations/pipx.py b/pyinfra/operations/pipx.py index 0b1e474fd..c4927e03e 100644 --- a/pyinfra/operations/pipx.py +++ b/pyinfra/operations/pipx.py @@ -2,25 +2,27 @@ Manage pipx (python) applications. """ +from typing import Optional, Union + from pyinfra import host from pyinfra.api import operation from pyinfra.facts.pipx import PipxEnvironment, PipxPackages from pyinfra.facts.server import Path -from .util.packaging import ensure_packages +from .util.packaging import PkgInfo, ensure_packages @operation() def packages( - packages=None, + packages: Optional[Union[str, list[str]]] = None, present=True, latest=False, - extra_args=None, + extra_args: Optional[str] = None, ): """ Install/remove/update pipx packages. 
- + packages: list of packages to ensure + + packages: list of packages (PEP-508 format) to ensure + present: whether the packages should be installed + latest: whether to upgrade packages without a specified version + extra_args: additional arguments to the pipx command @@ -37,6 +39,9 @@ def packages( packages=["pyinfra"], ) """ + if packages is None: + host.noop("no package list provided to pipx.packages") + return prep_install_command = ["pipx", "install"] @@ -47,19 +52,26 @@ def packages( uninstall_command = "pipx uninstall" upgrade_command = "pipx upgrade" - current_packages = host.get_fact(PipxPackages) + # PEP-0426 states that Python packages should be compared using lowercase, so lowercase the + # current packages. PkgInfo.from_pep508 takes care of it for the package names + current_packages = { + pkg.lower(): version for pkg, version in host.get_fact(PipxPackages).items() + } + if isinstance(packages, str): + packages = [packages] # pipx support only one package name at a time for package in packages: + if (pkg_info := PkgInfo.from_pep508(package)) is None: + continue # from_pep508 logged a warning yield from ensure_packages( host, - [package], + [pkg_info], current_packages, present, install_command=install_command, uninstall_command=uninstall_command, upgrade_command=upgrade_command, - version_join="==", latest=latest, ) diff --git a/pyinfra/operations/util/docker.py b/pyinfra/operations/util/docker.py index 43b289c78..c362ff03d 100644 --- a/pyinfra/operations/util/docker.py +++ b/pyinfra/operations/util/docker.py @@ -165,7 +165,46 @@ def _remove_network(**kwargs): return "docker network rm {0}".format(kwargs["network"]) -def handle_docker(resource, command, **kwargs): +def _install_plugin(**kwargs): + command = ["docker plugin install {0} --grant-all-permissions".format(kwargs["plugin"])] + + plugin_options = kwargs["plugin_options"] if kwargs["plugin_options"] else {} + + if kwargs["alias"]: + command.append("--alias {0}".format(kwargs["alias"])) + + 
if not kwargs["enabled"]: + command.append("--disable") + + for option, value in plugin_options.items(): + command.append("{0}={1}".format(option, value)) + + return " ".join(command) + + +def _remove_plugin(**kwargs): + return "docker plugin rm -f {0}".format(kwargs["plugin"]) + + +def _enable_plugin(**kwargs): + return "docker plugin enable {0}".format(kwargs["plugin"]) + + +def _disable_plugin(**kwargs): + return "docker plugin disable {0}".format(kwargs["plugin"]) + + +def _set_plugin_options(**kwargs): + command = ["docker plugin set {0}".format(kwargs["plugin"])] + existent_options = kwargs.get("existing_options", {}) + required_options = kwargs.get("required_options", {}) + options_to_set = existent_options | required_options + for option, value in options_to_set.items(): + command.append("{0}={1}".format(option, value)) + return " ".join(command) + + +def handle_docker(resource: str, command: str, **kwargs): container_commands = { "create": _create_container, "remove": _remove_container, @@ -192,12 +231,21 @@ def handle_docker(resource, command, **kwargs): "prune": _prune_command, } + plugin_commands = { + "install": _install_plugin, + "remove": _remove_plugin, + "enable": _enable_plugin, + "disable": _disable_plugin, + "set": _set_plugin_options, + } + docker_commands = { "container": container_commands, "image": image_commands, "volume": volume_commands, "network": network_commands, "system": system_commands, + "plugin": plugin_commands, } return docker_commands[resource][command](**kwargs) diff --git a/pyinfra/operations/util/files.py b/pyinfra/operations/util/files.py index bc63d5043..24d37bafa 100644 --- a/pyinfra/operations/util/files.py +++ b/pyinfra/operations/util/files.py @@ -1,9 +1,12 @@ from __future__ import annotations +import difflib import re from datetime import datetime, timezone from enum import Enum -from typing import Callable +from typing import Callable, Generator + +import click from pyinfra.api import QuoteString, StringCommand @@ 
-207,3 +210,34 @@ def adjust_regex(line: str, escape_regex_characters: bool) -> str: match_line = "{0}.*$".format(match_line) return match_line + + +def generate_color_diff( + current_lines: list[str], desired_lines: list[str] +) -> Generator[str, None, None]: + def _format_range_unified(start: int, stop: int) -> str: + beginning = start + 1 # lines start numbering with one + length = stop - start + if length == 1: + return "{}".format(beginning) + if not length: + beginning -= 1 # empty ranges begin at line just before the range + return "{},{}".format(beginning, length) + + for group in difflib.SequenceMatcher(None, current_lines, desired_lines).get_grouped_opcodes(2): + first, last = group[0], group[-1] + file1_range = _format_range_unified(first[1], last[2]) + file2_range = _format_range_unified(first[3], last[4]) + yield "@@ -{} +{} @@".format(file1_range, file2_range) + + for tag, i1, i2, j1, j2 in group: + if tag == "equal": + for line in current_lines[i1:i2]: + yield " " + line.rstrip() + continue + if tag in {"replace", "delete"}: + for line in current_lines[i1:i2]: + yield click.style("- " + line.rstrip(), "red") + if tag in {"replace", "insert"}: + for line in desired_lines[j1:j2]: + yield click.style("+ " + line.rstrip(), "green") diff --git a/pyinfra/operations/util/packaging.py b/pyinfra/operations/util/packaging.py index eb94e7ca0..e5c56374f 100644 --- a/pyinfra/operations/util/packaging.py +++ b/pyinfra/operations/util/packaging.py @@ -3,19 +3,83 @@ import shlex from collections import defaultdict from io import StringIO -from typing import Callable +from typing import Callable, NamedTuple, cast from urllib.parse import urlparse -from pyinfra.api import Host, State +from packaging.requirements import InvalidRequirement, Requirement + +from pyinfra import logger +from pyinfra.api import Host, OperationValueError, State from pyinfra.facts.files import File from pyinfra.facts.rpm import RpmPackage from pyinfra.operations import files -def 
_package_name(package: list[str] | str) -> str: - if isinstance(package, list): - return package[0] - return package +class PkgInfo(NamedTuple): + name: str + version: str + operator: str + url: str + """ + The key packaging information needed: version, operator and url are optional. + """ + + @property + def lkup_name(self) -> str | list[str]: + return self.name if self.version == "" else [self.name, self.version] + + @property + def has_version(self) -> bool: + return self.version != "" + + @property + def inst_vers(self) -> str: + return ( + self.url + if self.url != "" + else ( + self.operator.join([self.name, self.version]) if self.version != "" else self.name + ) + ) + + @classmethod + def from_possible_pair(cls, s: str, join: str | None) -> PkgInfo: + if join is not None: + pieces = s.rsplit(join, 1) + return cls(pieces[0], pieces[1] if len(pieces) > 1 else "", join, "") + + return cls(s, "", "", "") + + @classmethod + def from_pep508(cls, s: str) -> PkgInfo | None: + """ + Separate out the useful parts (name, url, operator, version) of a PEP-508 dependency. + Note: only one specifier is allowed. 
+ PEP-0426 states that Python packages should be compared using lowercase; thus + the name is lower-cased + For backwards compatibility, invalid requirements are assumed to be package names with a + warning that this will change in the next major release + """ + pep_508 = "PEP 508 non-compliant " + treatment = "requirement treated as package name" + will_change = "4.x will make this an error" # pip and pipx already throw away None's + try: + reqt = Requirement(s) + except InvalidRequirement as e: + logger.warning(f"{pep_508} :{e}\n{will_change}") + return cls(s, "", "", "") + else: + if (len(reqt.specifier) > 0) and (len(reqt.specifier) > 1): + logger.warning(f"{pep_508}/unsupported specifier ({s}) {treatment}\n{will_change}") + return cls(s, "", "", "") + else: + spec = next(iter(reqt.specifier), None) + return cls( + reqt.name.lower(), + spec.version if spec is not None else "", + spec.operator if spec is not None else "", + reqt.url or "", + ) def _has_package( @@ -57,12 +121,12 @@ def in_packages(pkg_name, pkg_versions): def ensure_packages( host: Host, - packages_to_ensure: str | list[str] | None, + packages_to_ensure: str | list[str] | list[PkgInfo] | None, current_packages: dict[str, set[str]], present: bool, install_command: str, uninstall_command: str, - latest=False, + latest: bool = False, upgrade_command: str | None = None, version_join: str | None = None, expand_package_fact: Callable[[str], list[str | list[str]]] | None = None, @@ -70,22 +134,22 @@ def ensure_packages( """ Handles this common scenario: - + We have a list of packages(/versions) to ensure + + We have a list of packages(/versions/urls) to ensure + We have a map of existing package -> versions + We have the common command bits (install, uninstall, version "joiner") + Outputs commands to ensure our desired packages/versions + Optionally upgrades packages w/o specified version when present Args: - packages_to_ensure (list): list of packages or package/versions - current_packages (fact): 
fact returning dict of package names -> version + packages_to_ensure (list): list of packages or package/versions or PkgInfo's + current_packages (dict): dict of package names -> version present (bool): whether packages should exist or not install_command (str): command to prefix to list of packages to install uninstall_command (str): as above for uninstalling packages latest (bool): whether to upgrade installed packages when present upgrade_command (str): as above for upgrading version_join (str): the package manager specific "joiner", ie ``=`` for \ - ``=`` + ``=``. Not allowed if (pkg, ver, url) tuples are provided. expand_package_fact: fact returning packages providing a capability \ (ie ``yum whatprovides``) """ @@ -95,12 +159,15 @@ def ensure_packages( if isinstance(packages_to_ensure, str): packages_to_ensure = [packages_to_ensure] - packages: list[str | list[str]] = packages_to_ensure # type: ignore[assignment] - - if version_join: + packages: list[PkgInfo] = [] + if isinstance(packages_to_ensure[0], PkgInfo): + packages = cast("list[PkgInfo]", packages_to_ensure) + if version_join is not None: + raise OperationValueError("cannot specify version_join and provide list[PkgInfo]") + else: packages = [ - package[0] if len(package) == 1 else package - for package in [package.rsplit(version_join, 1) for package in packages] # type: ignore[union-attr] # noqa + PkgInfo.from_possible_pair(package, version_join) + for package in cast("list[str]", packages_to_ensure) ] diff_packages = [] @@ -111,65 +178,41 @@ def ensure_packages( if present is True: for package in packages: has_package, expanded_packages = _has_package( - package, - current_packages, - expand_package_fact, + package.lkup_name, current_packages, expand_package_fact ) if not has_package: - diff_packages.append(package) - diff_expanded_packages[_package_name(package)] = expanded_packages + diff_packages.append(package.inst_vers) + diff_expanded_packages[package.name] = expanded_packages else: # Present 
packages w/o version specified - for upgrade if latest - if isinstance(package, str): - upgrade_packages.append(package) + if not package.has_version: # don't try to upgrade if a specific version requested + upgrade_packages.append(package.inst_vers) if not latest: - pkg_name = _package_name(package) - if pkg_name in current_packages: - host.noop( - "package {0} is installed ({1})".format( - package, - ", ".join(current_packages[pkg_name]), - ), - ) + if (pkg := package.name) in current_packages: + host.noop(f"package {pkg} is installed ({','.join(current_packages[pkg])})") else: - host.noop("package {0} is installed".format(package)) + host.noop(f"package {package.name} is installed") if present is False: for package in packages: - # String version, just check if existing has_package, expanded_packages = _has_package( - package, - current_packages, - expand_package_fact, - match_any=True, + package.lkup_name, current_packages, expand_package_fact, match_any=True ) if has_package: - diff_packages.append(package) - diff_expanded_packages[_package_name(package)] = expanded_packages + diff_packages.append(package.inst_vers) + diff_expanded_packages[package.name] = expanded_packages else: - host.noop("package {0} is not installed".format(package)) + host.noop(f"package {package.name} is not installed") if diff_packages: command = install_command if present else uninstall_command - - joined_packages = [ - version_join.join(package) if isinstance(package, list) else package # type: ignore[union-attr] # noqa - for package in diff_packages - ] - - yield "{0} {1}".format( - command, - " ".join([shlex.quote(pkg) for pkg in joined_packages]), - ) + yield f"{command} {' '.join([shlex.quote(pkg) for pkg in diff_packages])}" if latest and upgrade_command and upgrade_packages: - yield "{0} {1}".format( - upgrade_command, - " ".join([shlex.quote(pkg) for pkg in upgrade_packages]), - ) + yield f"{upgrade_command} {' '.join([shlex.quote(pkg) for pkg in upgrade_packages])}" def 
ensure_rpm(state: State, host: Host, source: str, present: bool, package_manager_command: str): diff --git a/pyinfra_cli/main.py b/pyinfra_cli/main.py index 98cdb615e..a3e25e10e 100644 --- a/pyinfra_cli/main.py +++ b/pyinfra_cli/main.py @@ -67,6 +67,12 @@ def _print_support(ctx, param, value): default=False, help="Don't execute operations on the target hosts.", ) +@click.option( + "--diff", + is_flag=True, + default=False, + help="Show the differences when changing text files and templates.", +) @click.option( "-y", "--yes", @@ -132,6 +138,18 @@ def _print_support(ctx, param, value): default=False, help="Run operations in serial, host by host.", ) +@click.option( + "--retry", + type=int, + default=0, + help="Number of times to retry failed operations.", +) +@click.option( + "--retry-delay", + type=int, + default=5, + help="Delay in seconds between retry attempts.", +) # SSH connector args # TODO: remove the non-ssh-prefixed variants @click.option("--ssh-user", "--user", "ssh_user", help="SSH user to connect as.") @@ -267,10 +285,13 @@ def _main( group_data, config_filename: str, dry: bool, + diff: bool, yes: bool, limit: Iterable, no_wait: bool, serial: bool, + retry: int, + retry_delay: int, debug: bool, debug_all: bool, debug_facts: bool, @@ -310,6 +331,9 @@ def _main( shell_executable, fail_percent, yes, + diff, + retry, + retry_delay, ) override_data = _set_override_data( data, @@ -549,6 +573,9 @@ def _set_config( shell_executable, fail_percent, yes, + diff, + retry, + retry_delay, ): logger.info("--> Loading config...") @@ -583,6 +610,15 @@ def _set_config( if fail_percent is not None: config.FAIL_PERCENT = fail_percent + if diff: + config.DIFF = True + + if retry is not None: + config.RETRY = retry + + if retry_delay is not None: + config.RETRY_DELAY = retry_delay + return config @@ -709,10 +745,13 @@ def _run_fact_operations(state, config, operations): def _prepare_exec_operations(state, config, operations): state.print_output = True + # Pass the retry 
settings from config to the shell operation load_func( state, server.shell, " ".join(operations), + _retries=config.RETRY, + _retry_delay=config.RETRY_DELAY, ) return state diff --git a/pyproject.toml b/pyproject.toml index 93ac44677..24ec0a015 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,3 +36,43 @@ ignore_missing_imports = true enable_incomplete_feature = "Unpack" enable_error_code = "explicit-override" files = "pyinfra,pyinfra_cli" + +[tool.ruff] +line-length = 100 +target-version = "py312" +output-format="concise" + +[tool.ruff.format] +docstring-code-format = true +indent-style = "space" +quote-style = "double" +skip-magic-trailing-comma = false +line-ending = "auto" + +[tool.ruff.lint] +exclude=["tests/**"] +preview = true +select = ["ALL"] + +# ignores are because: +# D, TD, DOC - not clean enough to worry about docstrings +# PLR, C - will deal with complexity related stuff later +# ISC - isort rules can conflict with ruff formatter +# CPY - need to figure out standard copyright notice +# COM812 - not clear why I need the trailing comma n in a bunch of cases +# E731 - I like to use lambdas +# ERA001 - can clean up leftovers later +# G004 - easier to use f-strings in logs and just a POC so perf isn't the issue + +# TODO - turn these back on once evertyhing else done +# TRY003 - I prefer to provide details +# EM101, EM102 - I like putting strings in my exceptions when I raise them +# PLW0108 - too many apply's with special_sum using them # TODO -see if I can just supply the callabe +# TODO - figure out how to be able to use .loc in quote/compare.py +# PD008 - I can't see the harm in being specific and need to use it in compare +# PLC1901 - I like to be specific in the compare +# TID252 - I like relative imports within a module +# RUF100 - don't complain about unused noqa until we've turned everything back on.. 
+ignore =["D", "TD", "DOC", "FIX", "CPY", "ISC", "PLR", "C", + "E731", 'D211', "D213", "COM812", "TRY003", "PD008", "PLW0108", "EM101", "EM102", + "PLC1901", "TID252", "G004", "ERA001", "RUF100"] diff --git a/tests/facts/dnf.DnfRepositories/repos-with-spaces.json b/tests/facts/dnf.DnfRepositories/repos-with-spaces.json index c0d0ae27f..cc556dadc 100644 --- a/tests/facts/dnf.DnfRepositories/repos-with-spaces.json +++ b/tests/facts/dnf.DnfRepositories/repos-with-spaces.json @@ -7,6 +7,7 @@ ], "fact": [ { + "repoid": "rhel-atomic-7-cdk-3.6-source-rpms", "name": "Red Hat Container Development Kit 3.6 /(Source RPMs)" } ] diff --git a/tests/facts/dnf.DnfRepositories/repos.json b/tests/facts/dnf.DnfRepositories/repos.json index 6741299f2..5bac4cb44 100644 --- a/tests/facts/dnf.DnfRepositories/repos.json +++ b/tests/facts/dnf.DnfRepositories/repos.json @@ -10,10 +10,12 @@ ], "fact": [ { + "repoid": "somerepo", "name": "somerepo", "baseurl": "abc" }, { + "repoid": "anotherrepo", "name": "anotherrepo" } ] diff --git a/tests/facts/docker.DockerPlugin/plugin.json b/tests/facts/docker.DockerPlugin/plugin.json new file mode 100644 index 000000000..d63d43a35 --- /dev/null +++ b/tests/facts/docker.DockerPlugin/plugin.json @@ -0,0 +1,11 @@ +{ + "arg": "myid", + "command": "docker plugin inspect myid 2>&- || true", + "requires_command": "docker", + "output": [ + "{\"hello\": \"world\"}" + ], + "fact": { + "hello": "world" + } +} diff --git a/tests/facts/docker.DockerPlugins/plugins.json b/tests/facts/docker.DockerPlugins/plugins.json new file mode 100644 index 000000000..f3b6baa8e --- /dev/null +++ b/tests/facts/docker.DockerPlugins/plugins.json @@ -0,0 +1,8 @@ +{ + "command": "ids=$(docker plugin ls -q) && [ -n \"$ids\" ] && docker plugin inspect $ids || echo \"[]\"", + "requires_command": "docker", + "output": ["{\"hello\": \"world\"}"], + "fact": { + "hello": "world" + } +} diff --git a/tests/facts/files.FileContents/file.json b/tests/facts/files.FileContents/file.json new file 
mode 100644 index 000000000..8b71de1c1 --- /dev/null +++ b/tests/facts/files.FileContents/file.json @@ -0,0 +1,6 @@ +{ + "arg": "myfile", + "command": "cat myfile", + "output": ["line1", "line2"], + "fact": ["line1", "line2"] +} diff --git a/tests/facts/files.FileContents/no_file.json b/tests/facts/files.FileContents/no_file.json new file mode 100644 index 000000000..cab556293 --- /dev/null +++ b/tests/facts/files.FileContents/no_file.json @@ -0,0 +1,6 @@ +{ + "arg": ["test"], + "command": "cat test", + "output": null, + "fact": null +} diff --git a/tests/facts/server.RebootRequired/alpine_reboot_required.json b/tests/facts/server.RebootRequired/alpine_reboot_required.json new file mode 100644 index 000000000..34dce95a3 --- /dev/null +++ b/tests/facts/server.RebootRequired/alpine_reboot_required.json @@ -0,0 +1,5 @@ +{ + "command": "\n# Get OS type\nOS_TYPE=$(uname -s)\nif [ \"$OS_TYPE\" = \"Linux\" ]; then\n # Check if it's Alpine Linux\n if [ -f /etc/alpine-release ]; then\n RUNNING_KERNEL=$(uname -r)\n INSTALLED_KERNEL=$(ls -1 /lib/modules | sort -V | tail -n1)\n if [ \"$RUNNING_KERNEL\" != \"$INSTALLED_KERNEL\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\n else\n # Check standard Linux reboot required files\n if [ -f /var/run/reboot-required ] || [ -f /var/run/reboot-required.pkgs ]; then\n echo \"reboot_required\"\n exit 0\n fi\n fi\nelif [ \"$OS_TYPE\" = \"FreeBSD\" ]; then\n RUNNING_VERSION=$(freebsd-version -r)\n INSTALLED_VERSION=$(freebsd-version -k)\n if [ \"$RUNNING_VERSION\" != \"$INSTALLED_VERSION\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\nfi\necho \"no_reboot_required\"\n", + "output": ["reboot_required"], + "fact": true +} diff --git a/tests/facts/server.RebootRequired/freebsd_reboot_required.json b/tests/facts/server.RebootRequired/freebsd_reboot_required.json new file mode 100644 index 000000000..efb3732a5 --- /dev/null +++ b/tests/facts/server.RebootRequired/freebsd_reboot_required.json @@ -0,0 +1,4 @@ +{ + "command": "\n# Get OS 
type\nOS_TYPE=$(uname -s)\nif [ \"$OS_TYPE\" = \"Linux\" ]; then\n # Check if it's Alpine Linux\n if [ -f /etc/alpine-release ]; then\n RUNNING_KERNEL=$(uname -r)\n INSTALLED_KERNEL=$(ls -1 /lib/modules | sort -V | tail -n1)\n if [ \"$RUNNING_KERNEL\" != \"$INSTALLED_KERNEL\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\n else\n # Check standard Linux reboot required files\n if [ -f /var/run/reboot-required ] || [ -f /var/run/reboot-required.pkgs ]; then\n echo \"reboot_required\"\n exit 0\n fi\n fi\nelif [ \"$OS_TYPE\" = \"FreeBSD\" ]; then\n RUNNING_VERSION=$(freebsd-version -r)\n INSTALLED_VERSION=$(freebsd-version -k)\n if [ \"$RUNNING_VERSION\" != \"$INSTALLED_VERSION\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\nfi\necho \"no_reboot_required\"\n", "output": ["reboot_required"], + "fact": true +} diff --git a/tests/facts/server.RebootRequired/linux_no_reboot.json b/tests/facts/server.RebootRequired/linux_no_reboot.json new file mode 100644 index 000000000..5ccba912c --- /dev/null +++ b/tests/facts/server.RebootRequired/linux_no_reboot.json @@ -0,0 +1,5 @@ +{ + "command": "\n# Get OS type\nOS_TYPE=$(uname -s)\nif [ \"$OS_TYPE\" = \"Linux\" ]; then\n # Check if it's Alpine Linux\n if [ -f /etc/alpine-release ]; then\n RUNNING_KERNEL=$(uname -r)\n INSTALLED_KERNEL=$(ls -1 /lib/modules | sort -V | tail -n1)\n if [ \"$RUNNING_KERNEL\" != \"$INSTALLED_KERNEL\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\n else\n # Check standard Linux reboot required files\n if [ -f /var/run/reboot-required ] || [ -f /var/run/reboot-required.pkgs ]; then\n echo \"reboot_required\"\n exit 0\n fi\n fi\nelif [ \"$OS_TYPE\" = \"FreeBSD\" ]; then\n RUNNING_VERSION=$(freebsd-version -r)\n INSTALLED_VERSION=$(freebsd-version -k)\n if [ \"$RUNNING_VERSION\" != \"$INSTALLED_VERSION\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\nfi\necho \"no_reboot_required\"\n", + "fact": false, + "output": ["no_reboot_required"] +} diff --git 
a/tests/facts/server.RebootRequired/linux_reboot_required.json b/tests/facts/server.RebootRequired/linux_reboot_required.json new file mode 100644 index 000000000..34dce95a3 --- /dev/null +++ b/tests/facts/server.RebootRequired/linux_reboot_required.json @@ -0,0 +1,5 @@ +{ + "command": "\n# Get OS type\nOS_TYPE=$(uname -s)\nif [ \"$OS_TYPE\" = \"Linux\" ]; then\n # Check if it's Alpine Linux\n if [ -f /etc/alpine-release ]; then\n RUNNING_KERNEL=$(uname -r)\n INSTALLED_KERNEL=$(ls -1 /lib/modules | sort -V | tail -n1)\n if [ \"$RUNNING_KERNEL\" != \"$INSTALLED_KERNEL\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\n else\n # Check standard Linux reboot required files\n if [ -f /var/run/reboot-required ] || [ -f /var/run/reboot-required.pkgs ]; then\n echo \"reboot_required\"\n exit 0\n fi\n fi\nelif [ \"$OS_TYPE\" = \"FreeBSD\" ]; then\n RUNNING_VERSION=$(freebsd-version -r)\n INSTALLED_VERSION=$(freebsd-version -k)\n if [ \"$RUNNING_VERSION\" != \"$INSTALLED_VERSION\" ]; then\n echo \"reboot_required\"\n exit 0\n fi\nfi\necho \"no_reboot_required\"\n", + "output": ["reboot_required"], + "fact": true +} diff --git a/tests/facts/yum.YumRepositories/repos.json b/tests/facts/yum.YumRepositories/repos.json index d310a861d..091ac312c 100644 --- a/tests/facts/yum.YumRepositories/repos.json +++ b/tests/facts/yum.YumRepositories/repos.json @@ -10,6 +10,7 @@ ], "fact": [ { + "repoid": "somerepo", "name": "somerepo", "baseurl": "abc" } diff --git a/tests/facts/zypper.ZypperRepositories/repos.json b/tests/facts/zypper.ZypperRepositories/repos.json index 0183262c4..7db623ee5 100644 --- a/tests/facts/zypper.ZypperRepositories/repos.json +++ b/tests/facts/zypper.ZypperRepositories/repos.json @@ -10,6 +10,7 @@ ], "fact": [ { + "repoid": "somerepo", "name": "somerepo", "baseurl": "abc" } diff --git a/tests/operations/docker.plugin/disable_plugin.json b/tests/operations/docker.plugin/disable_plugin.json new file mode 100644 index 000000000..70b63023e --- /dev/null +++ 
b/tests/operations/docker.plugin/disable_plugin.json @@ -0,0 +1,20 @@ +{ + "kwargs": { + "plugin": "my-plugin", + "enabled": false + }, + "facts": { + "docker.DockerPlugin": { + "object_id=my-plugin": [ + { + "Name": "my-plugin:latest", + "Enabled": true, + "Id": "1234567890abcdef" + } + ] + } + }, + "commands": [ + "docker plugin disable my-plugin" + ] +} \ No newline at end of file diff --git a/tests/operations/docker.plugin/enable_plugin.json b/tests/operations/docker.plugin/enable_plugin.json new file mode 100644 index 000000000..a449cbdcb --- /dev/null +++ b/tests/operations/docker.plugin/enable_plugin.json @@ -0,0 +1,19 @@ +{ + "kwargs": { + "plugin": "my-plugin" + }, + "facts": { + "docker.DockerPlugin": { + "object_id=my-plugin": [ + { + "Name": "my-plugin:latest", + "Enabled": false, + "Id": "1234567890abcdef" + } + ] + } + }, + "commands": [ + "docker plugin enable my-plugin" + ] +} \ No newline at end of file diff --git a/tests/operations/docker.plugin/install_plugin.json b/tests/operations/docker.plugin/install_plugin.json new file mode 100644 index 000000000..0db2baddd --- /dev/null +++ b/tests/operations/docker.plugin/install_plugin.json @@ -0,0 +1,13 @@ +{ + "kwargs": { + "plugin": "username/my-awesome-plugin:latest" + }, + "facts": { + "docker.DockerPlugin": { + "object_id=username/my-awesome-plugin:latest": [] + } + }, + "commands": [ + "docker plugin install username/my-awesome-plugin:latest --grant-all-permissions" + ] +} \ No newline at end of file diff --git a/tests/operations/docker.plugin/install_plugin_disabled.json b/tests/operations/docker.plugin/install_plugin_disabled.json new file mode 100644 index 000000000..694b68c7a --- /dev/null +++ b/tests/operations/docker.plugin/install_plugin_disabled.json @@ -0,0 +1,14 @@ +{ + "kwargs": { + "plugin": "username/my-awesome-plugin:latest", + "enabled": false + }, + "facts": { + "docker.DockerPlugin": { + "object_id=username/my-awesome-plugin:latest": [] + } + }, + "commands": [ + "docker plugin 
install username/my-awesome-plugin:latest --grant-all-permissions --disable" + ] +} \ No newline at end of file diff --git a/tests/operations/docker.plugin/install_plugin_with_alias_and_options.json b/tests/operations/docker.plugin/install_plugin_with_alias_and_options.json new file mode 100644 index 000000000..e7f074875 --- /dev/null +++ b/tests/operations/docker.plugin/install_plugin_with_alias_and_options.json @@ -0,0 +1,18 @@ +{ + "kwargs": { + "plugin": "username/my-awesome-plugin:latest", + "alias": "my-plugin", + "plugin_options": { + "MY_VAR1": "value1", + "MY_VAR2": "value2" + } + }, + "facts": { + "docker.DockerPlugin": { + "object_id=my-plugin": [] + } + }, + "commands": [ + "docker plugin install username/my-awesome-plugin:latest --grant-all-permissions --alias my-plugin MY_VAR1=value1 MY_VAR2=value2" + ] +} \ No newline at end of file diff --git a/tests/operations/docker.plugin/remove_plugin.json b/tests/operations/docker.plugin/remove_plugin.json new file mode 100644 index 000000000..4a74142eb --- /dev/null +++ b/tests/operations/docker.plugin/remove_plugin.json @@ -0,0 +1,20 @@ +{ + "kwargs": { + "plugin": "my-plugin", + "present": false + }, + "facts": { + "docker.DockerPlugin": { + "object_id=my-plugin": [ + { + "Name": "my-plugin:latest", + "Enabled": true, + "Id": "1234567890abcdef" + } + ] + } + }, + "commands": [ + "docker plugin rm -f my-plugin" + ] +} \ No newline at end of file diff --git a/tests/operations/docker.plugin/set_plugin_options.json b/tests/operations/docker.plugin/set_plugin_options.json new file mode 100644 index 000000000..ced5675a2 --- /dev/null +++ b/tests/operations/docker.plugin/set_plugin_options.json @@ -0,0 +1,31 @@ +{ + "kwargs": { + "plugin": "my-plugin", + "plugin_options": { + "MY_VAR1": "value1", + "MY_VAR2": "value3" + } + }, + "facts": { + "docker.DockerPlugin": { + "object_id=my-plugin": [ + { + "Name": "my-plugin:latest", + "Enabled": true, + "Id": "1234567890abcdef", + "Settings": { + "Env": { + "MY_VAR1": 
"value1", + "MY_VAR2": "value2" + } + } + } + ] + } + }, + "commands": [ + "docker plugin disable my-plugin", + "docker plugin set my-plugin MY_VAR1=value1 MY_VAR2=value3", + "docker plugin enable my-plugin" + ] +} \ No newline at end of file diff --git a/tests/operations/docker.plugin/set_plugin_options_disabled_plugin.json b/tests/operations/docker.plugin/set_plugin_options_disabled_plugin.json new file mode 100644 index 000000000..f57721a8d --- /dev/null +++ b/tests/operations/docker.plugin/set_plugin_options_disabled_plugin.json @@ -0,0 +1,30 @@ +{ + "kwargs": { + "plugin": "my-plugin", + "enabled": false, + "plugin_options": { + "MY_VAR1": "value1", + "MY_VAR2": "value3" + } + }, + "facts": { + "docker.DockerPlugin": { + "object_id=my-plugin": [ + { + "Name": "my-plugin:latest", + "Enabled": false, + "Id": "1234567890abcdef", + "Settings": { + "Env": { + "MY_VAR1": "value1", + "MY_VAR2": "value2" + } + } + } + ] + } + }, + "commands": [ + "docker plugin set my-plugin MY_VAR1=value1 MY_VAR2=value3" + ] +} \ No newline at end of file diff --git a/tests/operations/files.block/add_existing_block_different_content.json b/tests/operations/files.block/add_existing_block_different_content.json index b496f46a4..a7d5a92ae 100644 --- a/tests/operations/files.block/add_existing_block_different_content.json +++ b/tests/operations/files.block/add_existing_block_different_content.json @@ -12,6 +12,6 @@ } }, "commands": [ - "OUT=\"$(TMPDIR=/tmp mktemp -t pyinfra.XXXXXX)\" && awk 'BEGIN {{f=1; x=ARGV[2]; ARGV[2]=\"\"}}/# BEGIN PYINFRA BLOCK/ {print; print x; f=0} /# END PYINFRA BLOCK/ {print; f=1; next} f' /home/someone/something \"should be this\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" + "OUT=\"$(TMPDIR:=_tempdir_ mktemp -t 
pyinfra.XXXXXX)\" && awk 'BEGIN {{f=1; x=ARGV[2]; ARGV[2]=\"\"}}/# BEGIN PYINFRA BLOCK/ {print; print x; f=0} /# END PYINFRA BLOCK/ {print; f=1; next} f' /home/someone/something \"should be this\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" ] } diff --git a/tests/operations/files.block/add_existing_block_different_content_and_backup.json b/tests/operations/files.block/add_existing_block_different_content_and_backup.json index b1bcbacd1..3bb8301af 100644 --- a/tests/operations/files.block/add_existing_block_different_content_and_backup.json +++ b/tests/operations/files.block/add_existing_block_different_content_and_backup.json @@ -13,6 +13,6 @@ } }, "commands": [ - "cp /home/someone/something /home/someone/something.a-timestamp && OUT=\"$(TMPDIR=/tmp mktemp -t pyinfra.XXXXXX)\" && awk 'BEGIN {{f=1; x=ARGV[2]; ARGV[2]=\"\"}}/# BEGIN PYINFRA BLOCK/ {print; print x; f=0} /# END PYINFRA BLOCK/ {print; f=1; next} f' /home/someone/something \"should be this\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" + "cp /home/someone/something /home/someone/something.a-timestamp && OUT=\"$(TMPDIR:=_tempdir_ mktemp -t pyinfra.XXXXXX)\" && awk 'BEGIN {{f=1; x=ARGV[2]; ARGV[2]=\"\"}}/# BEGIN PYINFRA BLOCK/ {print; print x; f=0} /# END PYINFRA BLOCK/ {print; f=1; next} f' /home/someone/something \"should be this\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" 
/home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" ] } diff --git a/tests/operations/files.block/add_no_existing_block_line_provided.json b/tests/operations/files.block/add_no_existing_block_line_provided.json index d3e1fe3c8..f09d25549 100644 --- a/tests/operations/files.block/add_no_existing_block_line_provided.json +++ b/tests/operations/files.block/add_no_existing_block_line_provided.json @@ -12,6 +12,6 @@ } }, "commands": [ - "OUT=\"$(TMPDIR=/tmp mktemp -t pyinfra.XXXXXX)\" && awk 'BEGIN {x=ARGV[2]; ARGV[2]=\"\"} f!=1 && /^.*before this.*$/ { print x; f=1} END {if (f==0) print ARGV[2] } { print }' /home/someone/something \"# BEGIN PYINFRA BLOCK\nplease add this\n# END PYINFRA BLOCK\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" + "OUT=\"$(TMPDIR:=_tempdir_ mktemp -t pyinfra.XXXXXX)\" && awk 'BEGIN {x=ARGV[2]; ARGV[2]=\"\"} f!=1 && /^.*before this.*$/ { print x; f=1} END {if (f==0) print ARGV[2] } { print }' /home/someone/something \"# BEGIN PYINFRA BLOCK\nplease add this\n# END PYINFRA BLOCK\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" ] } diff --git a/tests/operations/files.block/add_no_existing_block_line_provided_escape_regex.json b/tests/operations/files.block/add_no_existing_block_line_provided_escape_regex.json index fca4d4d76..5ecb933e4 100644 --- a/tests/operations/files.block/add_no_existing_block_line_provided_escape_regex.json +++ b/tests/operations/files.block/add_no_existing_block_line_provided_escape_regex.json @@ -13,6 +13,6 @@ } }, "commands": [ 
- "OUT=\"$(TMPDIR=/tmp mktemp -t pyinfra.XXXXXX)\" && awk 'BEGIN {x=ARGV[2]; ARGV[2]=\"\"} f!=1 && /^.*before this \\*.*$/ { print x; f=1} END {if (f==0) print ARGV[2] } { print }' /home/someone/something \"# BEGIN PYINFRA BLOCK\nplease add this\n# END PYINFRA BLOCK\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" + "OUT=\"$(TMPDIR:=_tempdir_ mktemp -t pyinfra.XXXXXX)\" && awk 'BEGIN {x=ARGV[2]; ARGV[2]=\"\"} f!=1 && /^.*before this \\*.*$/ { print x; f=1} END {if (f==0) print ARGV[2] } { print }' /home/someone/something \"# BEGIN PYINFRA BLOCK\nplease add this\n# END PYINFRA BLOCK\" > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" ] } diff --git a/tests/operations/files.block/remove_but_content_not_none.json b/tests/operations/files.block/remove_but_content_not_none.json index 870a92509..ca72683f6 100644 --- a/tests/operations/files.block/remove_but_content_not_none.json +++ b/tests/operations/files.block/remove_but_content_not_none.json @@ -11,6 +11,6 @@ } }, "commands": [ - "OUT=\"$(TMPDIR=/tmp mktemp -t pyinfra.XXXXXX)\" && awk '/# BEGIN PYINFRA BLOCK/,/# END PYINFRA BLOCK/ {next} 1' /home/someone/something > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" + "OUT=\"$(TMPDIR:=_tempdir_ mktemp -t pyinfra.XXXXXX)\" && awk '/# BEGIN PYINFRA BLOCK/,/# END PYINFRA BLOCK/ {next} 1' 
/home/someone/something > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" ] } diff --git a/tests/operations/files.block/remove_existing_block.json b/tests/operations/files.block/remove_existing_block.json index f77e8d3da..69ba79905 100644 --- a/tests/operations/files.block/remove_existing_block.json +++ b/tests/operations/files.block/remove_existing_block.json @@ -10,6 +10,6 @@ } }, "commands": [ - "OUT=\"$(TMPDIR=/tmp mktemp -t pyinfra.XXXXXX)\" && awk '/# BEGIN PYINFRA BLOCK/,/# END PYINFRA BLOCK/ {next} 1' /home/someone/something > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" + "OUT=\"$(TMPDIR:=_tempdir_ mktemp -t pyinfra.XXXXXX)\" && awk '/# BEGIN PYINFRA BLOCK/,/# END PYINFRA BLOCK/ {next} 1' /home/someone/something > $OUT && chmod $(stat -c %a /home/someone/something 2>/dev/null || stat -f %Lp /home/someone/something ) $OUT && (chown $(stat -c \"%u:%g\" /home/someone/something 2>/dev/null || stat -f \"%u:%g\" /home/someone/something 2>/dev/null ) $OUT) && mv \"$OUT\" /home/someone/something" ] } diff --git a/tests/operations/files.put/different_remote.json b/tests/operations/files.put/different_remote.json index 683961d8e..610985c6f 100644 --- a/tests/operations/files.put/different_remote.json +++ b/tests/operations/files.put/different_remote.json @@ -24,6 +24,9 @@ }, "files.Sha1File": { "path=/home/somefile.txt": "nowt" + }, + "files.FileContents": { + "path=/home/somefile.txt": [] } }, "commands": [ diff --git a/tests/operations/files.put/fallback_md5.json 
b/tests/operations/files.put/fallback_md5.json index ade86b3f5..3b4ce62c8 100644 --- a/tests/operations/files.put/fallback_md5.json +++ b/tests/operations/files.put/fallback_md5.json @@ -20,9 +20,12 @@ }, "files.Md5File": { "path=/home/somefile.txt": "nowt" - } + }, + "files.FileContents": { + "path=/home/somefile.txt": null + }, }, "commands": [ ["upload", "/somefile.txt", "/home/somefile.txt"] ] -} \ No newline at end of file +} diff --git a/tests/operations/files.put/fallback_sha256.json b/tests/operations/files.put/fallback_sha256.json index d1caffac8..7f66b6729 100644 --- a/tests/operations/files.put/fallback_sha256.json +++ b/tests/operations/files.put/fallback_sha256.json @@ -23,9 +23,12 @@ }, "files.Sha256File": { "path=/home/somefile.txt": "nowt" - } + }, + "files.FileContents": { + "path=/home/somefile.txt": null + }, }, "commands": [ ["upload", "/somefile.txt", "/home/somefile.txt"] ] -} \ No newline at end of file +} diff --git a/tests/operations/pip.packages/add_existing_package_with_url.json b/tests/operations/pip.packages/add_existing_package_with_url.json new file mode 100644 index 000000000..a59066e63 --- /dev/null +++ b/tests/operations/pip.packages/add_existing_package_with_url.json @@ -0,0 +1,13 @@ +{ + "args": ["copier @ git@github.com:copier-org/copier.git@v9.9.0"], + "facts": { + "pip.PipPackages": { + "pip=pip": { + "copier": ["9.9.0"] + } + } + }, + "commands": [ + ], + "noop_description": "package copier is installed (9.9.0)" +} diff --git a/tests/operations/pipx.packages/add_existing_package_with_url.json b/tests/operations/pipx.packages/add_existing_package_with_url.json new file mode 100644 index 000000000..b07ffd3d9 --- /dev/null +++ b/tests/operations/pipx.packages/add_existing_package_with_url.json @@ -0,0 +1,9 @@ +{ + "args": ["copier @ git@github.com:copier-org/copier.git@v9.9.0"], + "facts": { + "pipx.PipxPackages": {"copier": ["9.9.0"]} + }, + "commands": [ + ], + "noop_description": "package copier is installed (9.9.0)" +} diff 
--git a/tests/operations/pipx.packages/add_nothing.json b/tests/operations/pipx.packages/add_nothing.json new file mode 100644 index 000000000..891097f42 --- /dev/null +++ b/tests/operations/pipx.packages/add_nothing.json @@ -0,0 +1,9 @@ +{ + "args": [null], + "facts": { + "pipx.PipxPackages": {"copier": ["9.9.0"]} + }, + "commands": [ + ], + "noop_description": "no package list provided to pipx.packages" +} diff --git a/tests/operations/pipx.packages/add_package.json b/tests/operations/pipx.packages/add_package.json new file mode 100644 index 000000000..ddb548789 --- /dev/null +++ b/tests/operations/pipx.packages/add_package.json @@ -0,0 +1,9 @@ +{ + "args": ["copier==0.9.1"], + "facts": { + "pipx.PipxPackages": {"ensurepath": ["0.1.1"]} + }, + "commands": [ + "pipx install copier==0.9.1" + ] +} diff --git a/tests/operations/server.user/keys_delete.json b/tests/operations/server.user/keys_delete.json index 0223ad573..3e2089c12 100644 --- a/tests/operations/server.user/keys_delete.json +++ b/tests/operations/server.user/keys_delete.json @@ -44,6 +44,9 @@ "files.Sha256File": { "path=homedir/.ssh/authorized_keys": null }, + "files.FileContents": { + "path=homedir/.ssh/authorized_keys": null + }, "server.Groups": {} }, "commands": [ diff --git a/tests/test_api/test_api_operations.py b/tests/test_api/test_api_operations.py index 3bee3dc74..f4a7542ba 100644 --- a/tests/test_api/test_api_operations.py +++ b/tests/test_api/test_api_operations.py @@ -19,6 +19,7 @@ from pyinfra.api.operation import OperationMeta, add_op from pyinfra.api.operations import run_ops from pyinfra.api.state import StateOperationMeta +from pyinfra.connectors.util import CommandOutput, OutputLine from pyinfra.context import ctx_host, ctx_state from pyinfra.operations import files, python, server @@ -576,4 +577,351 @@ def add_another_op(): assert op_order[1] == second_op_hash +class TestOperationRetry(PatchSSHTestCase): + """ + Tests for the retry functionality in operations. 
+ """ + + @patch("pyinfra.connectors.ssh.SSHConnector.run_shell_command") + def test_basic_retry_behavior(self, fake_run_command): + """ + Test that operations retry the correct number of times on failure. + """ + # Create inventory with just one host to simplify testing + inventory = make_inventory(hosts=("somehost",)) + state = State(inventory, Config()) + connect_all(state) + + # Add operation with retry settings + add_op( + state, + server.shell, + 'echo "testing retries"', + _retries=2, + _retry_delay=0.1, # Use small delay for tests + ) + + # Track how many times run_shell_command was called + call_count = 0 + + # First call fails, second succeeds + def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call fails + fake_channel = FakeChannel(1) + return (False, FakeBuffer("", fake_channel)) + else: + # Second call succeeds + fake_channel = FakeChannel(0) + return (True, FakeBuffer("success", fake_channel)) + + fake_run_command.side_effect = side_effect + + # Run the operation + run_ops(state) + + # Check that run_shell_command was called twice (original + 1 retry) + self.assertEqual(call_count, 2) + + # Verify results + somehost = inventory.get_host("somehost") + + # Operation should be successful (because the retry succeeded) + self.assertEqual(state.results[somehost].success_ops, 1) + self.assertEqual(state.results[somehost].error_ops, 0) + + # Get the operation hash + op_hash = state.get_op_order()[0] + + # Check retry info in OperationMeta + op_meta = state.ops[somehost][op_hash].operation_meta + self.assertEqual(op_meta.retry_attempts, 1) + self.assertEqual(op_meta.max_retries, 2) + self.assertTrue(op_meta.was_retried) + self.assertTrue(op_meta.retry_succeeded) + + @patch("pyinfra.connectors.ssh.SSHConnector.run_shell_command") + def test_retry_max_attempts_failure(self, fake_run_command): + """ + Test that operations stop retrying after max attempts and report failure. 
+ """ + inventory = make_inventory(hosts=("somehost",)) + state = State(inventory, Config()) + connect_all(state) + + # Add operation with retry settings + add_op( + state, + server.shell, + 'echo "testing max retries"', + _retries=2, + _retry_delay=0.1, + ) + + # Make all attempts fail + fake_channel = FakeChannel(1) + fake_run_command.return_value = (False, FakeBuffer("", fake_channel)) + + # This should fail after all retries + with self.assertRaises(PyinfraError) as e: + run_ops(state) + + self.assertEqual(e.exception.args[0], "No hosts remaining!") + + # Check that run_shell_command was called the right number of times (1 original + 2 retries) + self.assertEqual(fake_run_command.call_count, 3) + + somehost = inventory.get_host("somehost") + + # Operation should be marked as error + self.assertEqual(state.results[somehost].success_ops, 0) + self.assertEqual(state.results[somehost].error_ops, 1) + + # Get the operation hash + op_hash = state.get_op_order()[0] + + # Check retry info + op_meta = state.ops[somehost][op_hash].operation_meta + self.assertEqual(op_meta.retry_attempts, 2) + self.assertEqual(op_meta.max_retries, 2) + self.assertTrue(op_meta.was_retried) + self.assertFalse(op_meta.retry_succeeded) + + @patch("pyinfra.connectors.ssh.SSHConnector.run_shell_command") + @patch("time.sleep") + def test_retry_until_condition(self, fake_sleep, fake_run_command): + """ + Test that operations retry based on the retry_until callable condition. 
+ """ + # Setup inventory and state using the utility function + inventory = make_inventory(hosts=("somehost",)) + state = State(inventory, Config()) + connect_all(state) + + # Create a counter to track retry_until calls + call_counter = [0] + + # Create a retry_until function that returns True (retry) for first two calls + def retry_until_func(output_data): + call_counter[0] += 1 + return call_counter[0] < 3 # Retry twice, then stop + + # Add operation with retry_until + add_op( + state, + server.shell, + 'echo "test retry_until"', + _retries=3, + _retry_delay=0.1, + _retry_until=retry_until_func, + ) + + # Set up fake command execution - always succeed but with proper output format + # Use the existing FakeBuffer/FakeChannel from test utils + + # First two calls trigger retry_until, third doesn't + def command_side_effect(*args, **kwargs): + # Create proper CommandOutput for the retry_until function to process + lines = [OutputLine("stdout", "test output"), OutputLine("stderr", "no errors")] + return True, CommandOutput(lines) + + fake_run_command.side_effect = command_side_effect + + # Run the operations + run_ops(state) + + # The command should be called 3 times total (initial + 2 retries) + self.assertEqual(fake_run_command.call_count, 3) + + # The retry_until function should be called 3 times + self.assertEqual(call_counter[0], 3) + + # Get the operation metadata to check retry info + somehost = inventory.get_host("somehost") + op_hash = state.get_op_order()[0] + op_meta = state.ops[somehost][op_hash].operation_meta + + # Check retry metadata + self.assertEqual(op_meta.retry_attempts, 2) + self.assertEqual(op_meta.max_retries, 3) + self.assertTrue(op_meta.was_retried) + self.assertTrue(op_meta.retry_succeeded) + + @patch("pyinfra.connectors.ssh.SSHConnector.run_shell_command") + @patch("time.sleep") + def test_retry_delay(self, fake_sleep, fake_run_command): + """ + Test that retry delay is properly applied between attempts. 
+ """ + inventory = make_inventory(hosts=("somehost",)) + state = State(inventory, Config()) + connect_all(state) + + retry_delay = 5 + + # Add operation with retry settings + add_op( + state, + server.shell, + 'echo "testing retry delay"', + _retries=2, + _retry_delay=retry_delay, + ) + + # Make first call fail, second succeed + call_count = 0 + + def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + fake_channel = FakeChannel(1) + return (False, FakeBuffer("", fake_channel)) + else: + fake_channel = FakeChannel(0) + return (True, FakeBuffer("", fake_channel)) + + fake_run_command.side_effect = side_effect + + # Run the operation + run_ops(state) + + # Check that sleep was called with the correct delay + fake_sleep.assert_called_once_with(retry_delay) + + @patch("pyinfra.connectors.ssh.SSHConnector.run_shell_command") + @patch("time.sleep") + def test_retry_until_with_error_handling(self, fake_sleep, fake_run_command): + """ + Test that operations handle errors in retry_until functions gracefully. 
+ """ + inventory = make_inventory(hosts=("somehost",)) + state = State(inventory, Config()) + connect_all(state) + + # Create a retry_until function that raises an exception + def failing_retry_until_func(output_data): + raise ValueError("Test error in retry_until function") + + # Add operation with failing retry_until + add_op( + state, + server.shell, + 'echo "test failing retry_until"', + _retries=2, + _retry_delay=0.1, + _retry_until=failing_retry_until_func, + ) + + # Set up fake command execution + + def command_side_effect(*args, **kwargs): + lines = [OutputLine("stdout", "test output"), OutputLine("stderr", "no errors")] + return True, CommandOutput(lines) + + fake_run_command.side_effect = command_side_effect + + # Run the operations - should succeed despite retry_until error + run_ops(state) + + # The command should be called only once (no retries due to error) + self.assertEqual(fake_run_command.call_count, 1) + + # Verify operation completed successfully + somehost = inventory.get_host("somehost") + self.assertEqual(state.results[somehost].success_ops, 1) + self.assertEqual(state.results[somehost].error_ops, 0) + + @patch("pyinfra.connectors.ssh.SSHConnector.run_shell_command") + @patch("time.sleep") + def test_retry_until_with_complex_output_parsing(self, fake_sleep, fake_run_command): + """ + Test retry_until with complex output parsing scenarios. 
+ """ + inventory = make_inventory(hosts=("somehost",)) + state = State(inventory, Config()) + connect_all(state) + + # Track what output we've seen + outputs_seen = [] + + def complex_retry_until_func(output_data): + # Store the output data for verification + outputs_seen.append(output_data) + + # Check for specific patterns in stdout + stdout_text = " ".join(output_data["stdout_lines"]) + + # Continue retrying until we see "READY" in stdout + return "READY" not in stdout_text + + # Add operation with complex retry_until + add_op( + state, + server.shell, + 'echo "service status check"', + _retries=3, + _retry_delay=0.1, + _retry_until=complex_retry_until_func, + ) + + # Set up fake command execution with changing output + + call_count = 0 + + def command_side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + + if call_count == 1: + lines = [ + OutputLine("stdout", "Service starting..."), + OutputLine("stderr", "Loading config"), + ] + elif call_count == 2: + lines = [ + OutputLine("stdout", "Service initializing..."), + OutputLine("stderr", "Connecting to database"), + ] + else: # call_count == 3 + lines = [ + OutputLine("stdout", "Service READY"), + OutputLine("stderr", "All systems operational"), + ] + + return True, CommandOutput(lines) + + fake_run_command.side_effect = command_side_effect + + # Run the operations + run_ops(state) + + # The command should be called 3 times + self.assertEqual(fake_run_command.call_count, 3) + + # Verify retry_until was called 3 times with correct data + self.assertEqual(len(outputs_seen), 3) + + # Check the output data structure + for output_data in outputs_seen: + self.assertIn("stdout_lines", output_data) + self.assertIn("stderr_lines", output_data) + self.assertIn("commands", output_data) + self.assertIn("executed_commands", output_data) + self.assertIn("host", output_data) + self.assertIn("operation", output_data) + + # Verify operation metadata + somehost = inventory.get_host("somehost") + op_hash = 
state.get_op_order()[0] + op_meta = state.ops[somehost][op_hash].operation_meta + + self.assertEqual(op_meta.retry_attempts, 2) + self.assertEqual(op_meta.max_retries, 3) + self.assertTrue(op_meta.was_retried) + self.assertTrue(op_meta.retry_succeeded) + + this_filename = path.join("tests", "test_api", "test_api_operations.py") diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py index c2677864f..96ae936ac 100644 --- a/tests/test_cli/test_cli.py +++ b/tests/test_cli/test_cli.py @@ -188,5 +188,8 @@ def test_deploy_operation_direct(self): debug_all=False, debug_operations=False, config_filename="config.py", + diff=True, + retry=0, + retry_delay=5, ) assert e.args == (0,) diff --git a/tests/util.py b/tests/util.py index 100c5e169..3ae0b7488 100644 --- a/tests/util.py +++ b/tests/util.py @@ -190,6 +190,19 @@ def noop(self, description): def get_temp_filename(*args, **kwargs): return "_tempfile_" + def get_temp_dir_config(*args, **kwargs): + return "_tempdir_" + + def get_file( + self, + remote_filename, + filename_or_io, + remote_temp_filename=None, + print_output=False, + *arguments, + ): + return True + @staticmethod def _get_fact_key(fact_cls): return "{0}.{1}".format(fact_cls.__module__.split(".")[-1], fact_cls.__name__)