Skip to content
Merged
17 changes: 15 additions & 2 deletions pyinfra/api/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import IO, TYPE_CHECKING, Callable, Union

import gevent
from typing_extensions import Unpack
from typing_extensions import Unpack, override

from pyinfra.context import LocalContextObject, ctx_config, ctx_host

Expand Down Expand Up @@ -58,6 +58,7 @@ class QuoteString:
def __init__(self, obj: Union[str, "StringCommand"]):
self.obj = obj

@override
def __repr__(self) -> str:
return f"QuoteString({self.obj})"

Expand All @@ -68,6 +69,7 @@ class PyinfraCommand:
def __init__(self, **arguments: Unpack[ConnectorArguments]):
self.connector_arguments = arguments

@override
def __eq__(self, other) -> bool:
if isinstance(other, self.__class__) and repr(self) == repr(other):
return True
Expand All @@ -88,9 +90,11 @@ def __init__(
self.bits = bits
self.separator = _separator

@override
def __str__(self) -> str:
return self.get_masked_value()

@override
def __repr__(self) -> str:
return f"StringCommand({self.get_masked_value()})"

Expand Down Expand Up @@ -131,6 +135,7 @@ def get_masked_value(self) -> str:
],
)

@override
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
connector_arguments.update(self.connector_arguments)

Expand All @@ -155,9 +160,11 @@ def __init__(
self.dest = dest
self.remote_temp_filename = remote_temp_filename

@override
def __repr__(self):
return "FileUploadCommand({0}, {1})".format(self.src, self.dest)

@override
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
connector_arguments.update(self.connector_arguments)

Expand All @@ -184,9 +191,11 @@ def __init__(
self.dest = dest
self.remote_temp_filename = remote_temp_filename

@override
def __repr__(self):
return "FileDownloadCommand({0}, {1})".format(self.src, self.dest)

@override
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
connector_arguments.update(self.connector_arguments)

Expand All @@ -213,13 +222,15 @@ def __init__(
self.args = args
self.kwargs = func_kwargs

@override
def __repr__(self):
return "FunctionCommand({0}, {1}, {2})".format(
self.function.__name__,
self.args,
self.kwargs,
)

@override
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
argspec = getfullargspec(self.function)
if "state" in argspec.args and "host" in argspec.args:
Expand All @@ -231,7 +242,7 @@ def execute(self, state: "State", host: "Host", connector_arguments: ConnectorAr
self.function(*self.args, **self.kwargs)
return

def execute_function():
def execute_function() -> None:
with ctx_config.use(state.config.copy()):
with ctx_host.use(host):
self.function(*self.args, **self.kwargs)
Expand All @@ -247,9 +258,11 @@ def __init__(self, src: str, dest: str, flags, **kwargs: Unpack[ConnectorArgumen
self.dest = dest
self.flags = flags

@override
def __repr__(self):
return "RsyncCommand({0}, {1}, {2})".format(self.src, self.dest, self.flags)

@override
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
return host.rsync(
self.src,
Expand Down
7 changes: 5 additions & 2 deletions pyinfra/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import importlib_metadata
except ImportError:
import importlib.metadata as importlib_metadata # type: ignore[no-redef]

from os import path
from typing import Iterable, Optional, Set

from packaging.markers import Marker
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import Version
from typing_extensions import override

from pyinfra import __version__, state

Expand Down Expand Up @@ -207,6 +209,7 @@ def __init__(self, **kwargs):
for key, value in config.items():
setattr(self, key, value)

@override
def __setattr__(self, key, value):
super().__setattr__(key, value)

Expand All @@ -221,10 +224,10 @@ def set_current_state(self, config_state):
for key, value in config_state:
setattr(self, key, value)

def lock_current_state(self):
def lock_current_state(self) -> None:
self._locked_config = self.get_current_state()

def reset_locked_state(self):
def reset_locked_state(self) -> None:
self.set_current_state(self._locked_config)

def copy(self) -> "Config":
Expand Down
3 changes: 3 additions & 0 deletions pyinfra/api/facts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import click
import gevent
from paramiko import SSHException
from typing_extensions import override

from pyinfra import logger
from pyinfra.api import StringCommand
Expand Down Expand Up @@ -61,6 +62,7 @@ class FactBase(Generic[T]):
def requires_command(self, *args, **kwargs) -> str | None:
return None

@override
def __init_subclass__(cls) -> None:
super().__init_subclass__()
module_name = cls.__module__.replace("pyinfra.facts.", "")
Expand Down Expand Up @@ -97,6 +99,7 @@ class ShortFactBase(Generic[T]):
name: str
fact: Type[FactBase]

@override
def __init_subclass__(cls) -> None:
super().__init_subclass__()
module_name = cls.__module__.replace("pyinfra.facts.", "")
Expand Down
10 changes: 7 additions & 3 deletions pyinfra/api/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from uuid import uuid4

import click
from typing_extensions import Unpack
from typing_extensions import Unpack, override

from pyinfra import logger
from pyinfra.connectors.base import BaseConnector
Expand Down Expand Up @@ -75,9 +75,11 @@ def __getattr__(self, key: str):

raise AttributeError(f"Host `{self.host}` has no data `{key}`")

@override
def __setattr__(self, key: str, value: Any):
self.override_datas[key] = value

@override
def __str__(self):
return str(self.datas)

Expand Down Expand Up @@ -181,9 +183,11 @@ def init(self, state: "State") -> None:
padding_diff = longest_name_len - len(self.name)
self.print_prefix_padding = "".join(" " for _ in range(0, padding_diff))

@override
def __str__(self):
return "{0}".format(self.name)

@override
def __repr__(self):
return "Host({0})".format(self.name)

Expand Down Expand Up @@ -357,7 +361,7 @@ def get_fact(self, name_or_cls, *args, **kwargs):
# Connector proxy
#

def _check_state(self):
def _check_state(self) -> None:
if not self.state:
raise TypeError("Cannot call this function with no state!")

Expand Down Expand Up @@ -399,7 +403,7 @@ def connect(self, reason=None, show_errors: bool = True, raise_exceptions: bool
self.state.trigger_callbacks("host_connect", self)
self.connected = True

def disconnect(self):
def disconnect(self) -> None:
"""
Disconnect from the host using it's configured connector.
"""
Expand Down
3 changes: 2 additions & 1 deletion pyinfra/api/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional, cast

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, override

import pyinfra
from pyinfra import context, logger
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(self, hash, is_change: Optional[bool]):
self._hash = hash
self._maybe_is_change = is_change

@override
def __repr__(self) -> str:
"""
Return Operation object as a string.
Expand Down
2 changes: 1 addition & 1 deletion pyinfra/api/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class StateHostMeta:
ops_no_change = 0
op_hashes: set[str]

def __init__(self):
def __init__(self) -> None:
self.op_hashes = set()


Expand Down
2 changes: 1 addition & 1 deletion pyinfra/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_file(
**arguments: Unpack["ConnectorArguments"],
) -> bool: ...

def check_can_rsync(self):
def check_can_rsync(self) -> None:
raise NotImplementedError("This connector does not support rsync")

def rsync(
Expand Down
9 changes: 7 additions & 2 deletions pyinfra/connectors/chroot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Optional

import click
from typing_extensions import Unpack
from typing_extensions import Unpack, override

from pyinfra import local, logger
from pyinfra.api import QuoteString, StringCommand
Expand All @@ -22,7 +22,7 @@


@memoize
def show_warning():
def show_warning() -> None:
logger.warning("The @chroot connector is in beta!")


Expand All @@ -39,6 +39,7 @@ def __init__(self, state: "State", host: "Host"):
super().__init__(state, host)
self.local = LocalConnector(state, host)

@override
@staticmethod
def make_names_data(name: Optional[str] = None):
if not name:
Expand All @@ -50,6 +51,7 @@ def make_names_data(name: Optional[str] = None):
"chroot_directory": "/{0}".format(name.lstrip("/")),
}, ["@chroot"]

@override
def connect(self) -> None:
self.local.connect()

Expand All @@ -66,6 +68,7 @@ def connect(self) -> None:

self.host.connector_data["chroot_directory"] = chroot_directory

@override
def run_shell_command(
self,
command,
Expand Down Expand Up @@ -97,6 +100,7 @@ def run_shell_command(
**local_arguments,
)

@override
def put_file(
self,
filename_or_io,
Expand Down Expand Up @@ -148,6 +152,7 @@ def put_file(

return status

@override
def get_file(
self,
remote_filename,
Expand Down
9 changes: 7 additions & 2 deletions pyinfra/connectors/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

import click
from typing_extensions import TypedDict, Unpack
from typing_extensions import TypedDict, Unpack, override

from pyinfra import local, logger
from pyinfra.api import QuoteString, StringCommand
Expand Down Expand Up @@ -115,6 +115,7 @@ def make_names_data(name=None):
["@docker"],
)

@override
def connect(self) -> None:
self.local.connect()

Expand All @@ -127,7 +128,8 @@ def connect(self) -> None:
except PyinfraError:
self.container_id = _start_docker_image(docker_identifier)

def disconnect(self):
@override
def disconnect(self) -> None:
container_id = self.container_id

if self.no_stop:
Expand Down Expand Up @@ -156,6 +158,7 @@ def disconnect(self):
),
)

@override
def run_shell_command(
self,
command: StringCommand,
Expand Down Expand Up @@ -188,6 +191,7 @@ def run_shell_command(
**local_arguments,
)

@override
def put_file(
self,
filename_or_io,
Expand Down Expand Up @@ -245,6 +249,7 @@ def put_file(

return status

@override
def get_file(
self,
remote_filename,
Expand Down
Loading