Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/pyinfra/api/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class ConnectorArguments(TypedDict, total=False):
_retry_delay: Union[int, float]
_retry_until: Optional[Callable[[dict], bool]]

# Temp directory argument
_temp_dir: str


def generate_env(config: "Config", value: dict) -> dict:
env = config.ENV.copy()
Expand Down Expand Up @@ -163,6 +166,10 @@ def generate_env(config: "Config", value: dict) -> dict:
"String or buffer to send to the stdin of any commands.",
default=lambda _: None,
),
"_temp_dir": ArgumentMeta(
"Temporary directory on the remote host for file operations.",
default=lambda config: config.TEMP_DIR,
),
}


Expand Down
2 changes: 2 additions & 0 deletions src/pyinfra/connectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ def make_unix_command(
_retries=0,
_retry_delay=0,
_retry_until=None,
# Temp dir config (ignored in command generation, used for temp file path generation)
_temp_dir=None,
) -> StringCommand:
"""
Builds a shell command with various kwargs.
Expand Down
41 changes: 34 additions & 7 deletions src/pyinfra/operations/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,12 @@ def download(

# If we download, always do user/group/mode as SSH user may be different
if download:
# Use explicit temp_dir if provided, otherwise fall back to _temp_dir global argument
effective_temp_dir = temp_dir
if effective_temp_dir is None and host.current_op_global_arguments:
effective_temp_dir = host.current_op_global_arguments.get("_temp_dir")
temp_file = host.get_temp_filename(
dest, temp_directory=str(temp_dir) if temp_dir is not None else None
dest, temp_directory=str(effective_temp_dir) if effective_temp_dir is not None else None
)

curl_args: list[Union[str, StringCommand]] = ["-sSLf"]
Expand Down Expand Up @@ -797,19 +801,32 @@ def get(

remote_file = host.get_fact(File, path=src)

# Use _temp_dir global argument if provided
temp_dir = None
if host.current_op_global_arguments:
temp_dir = host.current_op_global_arguments.get("_temp_dir")

# No remote file, so assume exists and download it "blind"
if not remote_file or force:
yield FileDownloadCommand(src, dest, remote_temp_filename=host.get_temp_filename(dest))
yield FileDownloadCommand(
src, dest, remote_temp_filename=host.get_temp_filename(dest, temp_directory=temp_dir)
)

# No local file, so always download
elif not os.path.exists(dest):
yield FileDownloadCommand(src, dest, remote_temp_filename=host.get_temp_filename(dest))
yield FileDownloadCommand(
src, dest, remote_temp_filename=host.get_temp_filename(dest, temp_directory=temp_dir)
)

# Remote file exists - check if it matches our local
else:
# Check hash sum, download if needed
if not _file_equal(dest, src):
yield FileDownloadCommand(src, dest, remote_temp_filename=host.get_temp_filename(dest))
yield FileDownloadCommand(
src,
dest,
remote_temp_filename=host.get_temp_filename(dest, temp_directory=temp_dir),
)
else:
host.noop("file {0} has already been downloaded".format(dest))

Expand Down Expand Up @@ -998,6 +1015,11 @@ def put(
if create_remote_dir:
yield from _create_remote_dir(dest, user, group)

# Use _temp_dir global argument if provided
temp_dir = None
if host.current_op_global_arguments:
temp_dir = host.current_op_global_arguments.get("_temp_dir")

# No remote file, always upload and user/group/mode if supplied
if not remote_file or force:
if state.config.DIFF:
Expand All @@ -1013,7 +1035,7 @@ def put(
yield FileUploadCommand(
local_file,
dest,
remote_temp_filename=host.get_temp_filename(dest),
remote_temp_filename=host.get_temp_filename(dest, temp_directory=temp_dir),
)

if user or group:
Expand Down Expand Up @@ -1060,7 +1082,7 @@ def put(
yield FileUploadCommand(
local_file,
dest,
remote_temp_filename=host.get_temp_filename(dest),
remote_temp_filename=host.get_temp_filename(dest, temp_directory=temp_dir),
)

if user or group:
Expand Down Expand Up @@ -1803,7 +1825,12 @@ def block(
current = host.get_fact(Block, path=path, marker=marker, begin=begin, end=end)
cmd = None

tmp_dir = host.get_temp_dir_config()
# Use _temp_dir global argument if provided, otherwise fall back to config
tmp_dir = None
if host.current_op_global_arguments:
tmp_dir = host.current_op_global_arguments.get("_temp_dir")
if not tmp_dir:
tmp_dir = host.get_temp_dir_config()

# standard awk doesn't have an "in-place edit" option so we write to a tempfile and
# if edits were successful move to dest i.e. we do: <out_prep> ... do some work ... <real_out>
Expand Down