diff --git a/examples/_mocks.py b/examples/_mocks.py index 04a792c8..6b27fccf 100644 --- a/examples/_mocks.py +++ b/examples/_mocks.py @@ -97,4 +97,4 @@ def create_fake_rig(): computer_name = os.getenv("COMPUTERNAME") os.makedirs(_dir := f"{LIB_CONFIG}/Rig/{computer_name}", exist_ok=True) with open(f"{_dir}/rig1.json", "w", encoding="utf-8") as f: - f.write(RigModel().model_dump_json(indent=2)) + f.write(RigModel(data_directory=r"./local/data").model_dump_json(indent=2)) diff --git a/examples/client_behavior_launcher.py b/examples/client_behavior_launcher.py new file mode 100644 index 00000000..c4481aae --- /dev/null +++ b/examples/client_behavior_launcher.py @@ -0,0 +1,86 @@ +import logging +from pathlib import Path + +from _mocks import ( + LIB_CONFIG, + AindBehaviorSessionModel, + RigModel, + TaskLogicModel, + create_fake_rig, + create_fake_subjects, +) +from pydantic_settings import CliApp + +from clabe import resource_monitor +from clabe.apps import BonsaiApp +from clabe.launcher import Launcher, LauncherCliArgs, experiment +from clabe.pickers import DefaultBehaviorPicker, DefaultBehaviorPickerSettings +from clabe.xml_rpc import XmlRpcClient, XmlRpcClientSettings + +logger = logging.getLogger(__name__) + + +@experiment() +async def client_experiment(launcher: Launcher) -> None: + """Demo experiment showcasing CLABE functionality.""" + picker = DefaultBehaviorPicker( + launcher=launcher, + settings=DefaultBehaviorPickerSettings(config_library_dir=LIB_CONFIG), + experimenter_validator=lambda _: True, + ) + + session = picker.pick_session(AindBehaviorSessionModel) + rig = picker.pick_rig(RigModel) + launcher.register_session(session, rig.data_directory) + trainer_state, task_logic = picker.pick_trainer_state(TaskLogicModel) + + resource_monitor.ResourceMonitor( + constrains=[ + resource_monitor.available_storage_constraint_factory_from_rig(rig, 2e11), + ] + ).run() + + xml_rpc_client = XmlRpcClient(settings=XmlRpcClientSettings(server_url="http://localhost:8000", token="42")) + + bonsai_root = Path(r"C:\git\AllenNeuralDynamics\Aind.Behavior.VrForaging") + session_response = xml_rpc_client.upload_model(session, "session.json") + rig_response = xml_rpc_client.upload_model(rig, "rig.json") + task_logic_response = xml_rpc_client.upload_model(task_logic, "task_logic.json") + assert rig_response.path is not None + assert session_response.path is not None + assert task_logic_response.path is not None + + bonsai_app_result = await xml_rpc_client.run_async( + BonsaiApp( + workflow=bonsai_root / "src/test_deserialization.bonsai", + executable=bonsai_root / "bonsai/bonsai.exe", + additional_externalized_properties={ + "RigPath": rig_response.path, + "SessionPath": session_response.path, + "TaskLogicPath": task_logic_response.path, + }, + ).command + ) + print(bonsai_app_result) + return + + +def main(): + create_fake_subjects() + create_fake_rig() + behavior_cli_args = CliApp.run( + LauncherCliArgs, + cli_args=[ + "--debug-mode", + "--allow-dirty", + "--skip-hardware-validation", + ], + ) + + launcher = Launcher(settings=behavior_cli_args) + launcher.run_experiment(client_experiment) + return None + + +if __name__ == "__main__": + main() diff --git a/src/clabe/apps/__init__.py b/src/clabe/apps/__init__.py index 200fd2aa..334a47df 100644 --- a/src/clabe/apps/__init__.py +++ b/src/clabe/apps/__init__.py @@ -5,8 +5,8 @@ CommandResult, ExecutableApp, Executor, - OutputParser, StdCommand, + _OutputParser, identity_parser, ) from ._bonsai import AindBehaviorServicesBonsaiApp, BonsaiApp @@ -26,7 +26,7 @@ "AsyncExecutor", "Executor", "identity_parser", - "OutputParser", + "_OutputParser", "PythonScriptApp", "ExecutableApp", "StdCommand", diff --git a/src/clabe/apps/_base.py b/src/clabe/apps/_base.py index 794f4bb1..a039d102 100644 --- a/src/clabe/apps/_base.py +++ b/src/clabe/apps/_base.py @@ -111,7 +111,7 @@ class ExecutableApp(Protocol): class MyApp(ExecutableApp): @property def command(self) -> Command: - return Command(cmd="echo hello", output_parser=identity_parser) + return Command(cmd=["echo", "hello"], output_parser=identity_parser) ``` """ @@ -170,7 +170,7 @@ async def run_async(self, command: "Command") -> CommandResult: TOutput = TypeVar("TOutput") -OutputParser: TypeAlias = Callable[[CommandResult], TOutput] +_OutputParser: TypeAlias = Callable[[CommandResult], TOutput] class Command(Generic[TOutput]): @@ -181,14 +181,20 @@ class Command(Generic[TOutput]): Supports both synchronous and asynchronous execution patterns with type-safe output parsing. + Commands are provided as a list of strings, which is consistent with subprocess + and executed directly without shell interpretation. This approach: + - Avoids shell injection vulnerabilities + - Handles arguments with spaces correctly without manual quoting + - Is more portable across platforms + Attributes: - cmd: The command string to execute + cmd: The command to execute as a list of strings result: The result of command execution (available after execution) Example: ```python - # Create a simple command - cmd = Command(cmd="echo hello", output_parser=identity_parser) + # Create a command + cmd = Command(cmd=["python", "-c", "print('hello')"], output_parser=identity_parser) # Execute with a synchronous executor executor = LocalExecutor() @@ -198,24 +204,25 @@ class Command(Generic[TOutput]): def parse_json(result: CommandResult) -> dict: return json.loads(result.stdout) - cmd = Command(cmd="get-data --json", output_parser=parse_json) + cmd = Command(cmd=["get-data", "--json"], output_parser=parse_json) data = cmd.execute(executor) ``` """ - def __init__(self, cmd: str, output_parser: OutputParser[TOutput]) -> None: + def __init__(self, cmd: list[str], output_parser: _OutputParser[TOutput]) -> None: """Initialize the Command instance. + Args: - cmd: The command string to execute + cmd: The command to execute as a list of strings. The first element + is the program to run, followed by its arguments. output_parser: Function to parse the command result into desired output type Example: ```python - # Create a simple command - cmd = Command(cmd="echo hello", output_parser=identity_parser) + cmd = Command(cmd=["echo", "hello"], output_parser=identity_parser) ``` """ - self._cmd = cmd + self._cmd: list[str] = cmd self._output_parser = output_parser self._result: Optional[CommandResult] = None @@ -227,16 +234,30 @@ def result(self) -> CommandResult: return self._result @property - def cmd(self) -> str: - """Get the command string.""" + def cmd(self) -> list[str]: + """Get the command as a list of strings.""" return self._cmd def append_arg(self, args: str | list[str]) -> Self: - """Append an argument to the command.""" + """Append arguments to the command. + + Args: + args: Argument(s) to append. Can be a single string or list of strings. + Empty strings are filtered out. + + Returns: + Self for method chaining. + + Example: + ```python + cmd = Command(cmd=["python"], output_parser=identity_parser) + cmd.append_arg(["-m", "pytest"]) # Results in ["python", "-m", "pytest"] + ``` + """ if isinstance(args, str): args = [args] args = [arg for arg in args if arg] - self._cmd = (self.cmd + f" {' '.join(args)}").strip() + self._cmd = self._cmd + args return self def execute(self, executor: Executor) -> TOutput: @@ -267,9 +288,18 @@ def _parse_output(self, result: CommandResult) -> TOutput: class StdCommand(Command[CommandResult]): - """Standard command that returns the raw CommandResult.""" + """Standard command that returns the raw CommandResult. + + A convenience class that creates a Command with the identity_parser, + returning the raw CommandResult without transformation. + + Example: + ```python + cmd = StdCommand(["echo", "hello"]) + ``` + """ - def __init__(self, cmd: str) -> None: + def __init__(self, cmd: list[str]) -> None: super().__init__(cmd, identity_parser) diff --git a/src/clabe/apps/_bonsai.py b/src/clabe/apps/_bonsai.py index 9320faa9..e4fb0b24 100644 --- a/src/clabe/apps/_bonsai.py +++ b/src/clabe/apps/_bonsai.py @@ -4,7 +4,7 @@ import random from os import PathLike from pathlib import Path -from typing import Dict, Optional +from typing import Dict, List, Optional import pydantic from aind_behavior_services import AindBehaviorRigModel, AindBehaviorSessionModel, AindBehaviorTaskLogicModel @@ -117,13 +117,18 @@ def _build_bonsai_process_command( is_editor_mode: bool = True, is_start_flag: bool = True, additional_properties: Optional[Dict[str, str]] = None, - ) -> str: + ) -> List[str]: """ - Builds a shell command for running a Bonsai workflow via subprocess. + Builds a command list for running a Bonsai workflow via subprocess. - Constructs the complete command string with all necessary flags and properties - for executing a Bonsai workflow. Handles editor mode, start flag, and - externalized properties. + Constructs the complete command as a list of arguments with all necessary + flags and properties for executing a Bonsai workflow. Handles editor mode, + start flag, and externalized properties. + + Using list format is preferred over string format as it: + - Avoids shell injection vulnerabilities + - Handles paths with spaces correctly without manual quoting + - Is more portable across platforms Args: workflow_file: Path to the Bonsai workflow file @@ -133,7 +138,7 @@ def _build_bonsai_process_command( additional_properties: Dictionary of externalized properties to pass. Defaults to None Returns: - str: The complete command string + List[str]: The complete command as a list of arguments Example: ```python @@ -142,19 +147,20 @@ def _build_bonsai_process_command( is_editor_mode=False, additional_properties={"SubjectName": "Mouse123"} ) - # Returns: '"bonsai.exe" "workflow.bonsai" --no-editor -p:"SubjectName"="Mouse123"' + # Returns: ["bonsai.exe", "workflow.bonsai", "--no-editor", "-p:SubjectName=Mouse123"] ``` """ - output_cmd: str = f'"{bonsai_exe}" "{workflow_file}"' + output_cmd: List[str] = [str(bonsai_exe), str(workflow_file)] + if is_editor_mode: if is_start_flag: - output_cmd += " --start" + output_cmd.append("--start") else: - output_cmd += " --no-editor" + output_cmd.append("--no-editor") if additional_properties: for param, value in additional_properties.items(): - output_cmd += f' -p:"{param}"="{value}"' + output_cmd.append(f"-p:{param}={value}") return output_cmd diff --git a/src/clabe/apps/_curriculum.py b/src/clabe/apps/_curriculum.py index 025c8a87..a5efed79 100644 --- a/src/clabe/apps/_curriculum.py +++ b/src/clabe/apps/_curriculum.py @@ -69,7 +69,7 @@ class CurriculumSettings(ServiceSettings): __yml_section__: t.ClassVar[t.Optional[str]] = "curriculum" - script: str = "curriculum run" + script: list[str] = ["curriculum", "run"] project_directory: os.PathLike = Path(".") input_trainer_state: t.Optional[os.PathLike] = None data_directory: t.Optional[os.PathLike] = None @@ -145,18 +145,18 @@ def __init__( raise ValueError("Data directory is not set.") kwargs: dict[str, t.Any] = { # Must use kebab casing - "data-directory": f'"{self._settings.data_directory}"', - "input-trainer-state": f'"{self._settings.input_trainer_state}"', + "data-directory": str(self._settings.data_directory), + "input-trainer-state": str(self._settings.input_trainer_state), } if self._settings.curriculum is not None: - kwargs["curriculum"] = f'"{self._settings.curriculum}"' + kwargs["curriculum"] = str(self._settings.curriculum) python_script_app_kwargs = python_script_app_kwargs or {} self._python_script_app = PythonScriptApp( script=settings.script, project_directory=settings.project_directory, extra_uv_arguments="-q", - additional_arguments=" ".join(f"--{key} {value}" for key, value in kwargs.items()), + additional_arguments=[arg for kv in kwargs.items() for arg in ("--" + kv[0], str(kv[1]))], **python_script_app_kwargs, ) diff --git a/src/clabe/apps/_executors.py b/src/clabe/apps/_executors.py index 2d38f6ec..a001eb2b 100644 --- a/src/clabe/apps/_executors.py +++ b/src/clabe/apps/_executors.py @@ -14,9 +14,14 @@ class LocalExecutor(Executor): and environment variables. Captures both stdout and stderr, and enforces return code checking. + Commands are executed directly without shell interpretation (shell=False), + which avoids shell injection vulnerabilities and handles arguments with + spaces correctly. + Attributes: cwd: Working directory for command execution env: Environment variables for the subprocess + timeout: Maximum execution time in seconds Example: ```python @@ -30,7 +35,7 @@ class LocalExecutor(Executor): executor = LocalExecutor(env={"KEY": "value"}) # Execute a command - cmd = Command(cmd="echo hello", output_parser=identity_parser) + cmd = Command(cmd=["echo", "hello"], output_parser=identity_parser) result = executor.run(cmd) ``` """ @@ -43,6 +48,7 @@ def __init__( Args: cwd: Working directory for command execution env: Environment variables for the subprocess + timeout: Maximum execution time in seconds """ self.cwd = cwd or os.getcwd() @@ -51,17 +57,32 @@ def __init__( def run(self, command: Command[Any]) -> CommandResult: """Execute the command and return the result. + Args: - command: The command to execute + command: The command to execute (as a list of strings) + + Returns: + CommandResult with stdout, stderr, and exit code + + Raises: + CommandError: If the command exits with non-zero exit code + Example: ```python executor = LocalExecutor() - cmd = Command(cmd="echo hello", output_parser=identity_parser) + cmd = Command(cmd=["echo", "hello"], output_parser=identity_parser) result = executor.run(cmd) ``` """ proc = subprocess.run( - command.cmd, cwd=self.cwd, env=self.env, text=True, capture_output=True, check=False, timeout=self.timeout + command.cmd, + cwd=self.cwd, + env=self.env, + text=True, + capture_output=True, + check=False, + timeout=self.timeout, + shell=False, ) result = CommandResult(stdout=proc.stdout, stderr=proc.stderr, exit_code=proc.returncode) result.check_returncode() @@ -72,13 +93,17 @@ class AsyncLocalExecutor(AsyncExecutor): """ Asynchronous executor for running commands on the local system. - Executes commands asynchronously using asyncio.create_subprocess_shell with + Executes commands asynchronously using asyncio subprocess functions with configurable working directory and environment variables. Ideal for long-running processes or when multiple commands need to run concurrently. + Commands are executed directly without shell interpretation, which avoids + shell injection vulnerabilities and handles arguments with spaces correctly. + Attributes: cwd: Working directory for command execution env: Environment variables for the subprocess + timeout: Maximum execution time in seconds Example: ```python @@ -86,13 +111,13 @@ class AsyncLocalExecutor(AsyncExecutor): executor = AsyncLocalExecutor() # Execute a command asynchronously - cmd = Command(cmd="echo hello", output_parser=identity_parser) + cmd = Command(cmd=["echo", "hello"], output_parser=identity_parser) result = await executor.run_async(cmd) # Run multiple commands concurrently executor = AsyncLocalExecutor(cwd="/workdir") - cmd1 = Command(cmd="task1", output_parser=identity_parser) - cmd2 = Command(cmd="task2", output_parser=identity_parser) + cmd1 = Command(cmd=["task1"], output_parser=identity_parser) + cmd2 = Command(cmd=["task2"], output_parser=identity_parser) results = await asyncio.gather( executor.run_async(cmd1), executor.run_async(cmd2) @@ -108,6 +133,7 @@ def __init__( Args: cwd: Working directory for command execution env: Environment variables for the subprocess + timeout: Maximum execution time in seconds """ self.cwd = cwd or os.getcwd() @@ -118,17 +144,24 @@ async def run_async(self, command: Command) -> CommandResult: """Execute the command asynchronously and return the result. Args: - command: The command to execute + command: The command to execute (as a list of strings) + + Returns: + CommandResult with stdout, stderr, and exit code + + Raises: + CommandError: If the command exits with non-zero exit code + TimeoutError: If the command exceeds the timeout Example: ```python executor = AsyncLocalExecutor() - cmd = Command(cmd="echo hello", output_parser=identity_parser) + cmd = Command(cmd=["echo", "hello"], output_parser=identity_parser) result = await executor.run_async(cmd) ``` """ - proc = await asyncio.create_subprocess_shell( - command.cmd, + proc = await asyncio.create_subprocess_exec( + *command.cmd, cwd=self.cwd, env=self.env, stdout=asyncio.subprocess.PIPE, @@ -141,7 +174,7 @@ async def run_async(self, command: Command) -> CommandResult: proc.kill() await proc.wait() assert self.timeout is not None - raise subprocess.TimeoutExpired(command.cmd, self.timeout) from exc + raise subprocess.TimeoutExpired(" ".join(command.cmd), self.timeout) from exc if proc.returncode is None: raise RuntimeError("Process did not complete successfully and returned no return code.") @@ -169,7 +202,7 @@ class _DefaultExecutorMixin: class MyApp(ExecutableApp, _DefaultExecutorMixin): @property def command(self) -> Command: - return Command(cmd="echo hello", output_parser=identity_parser) + return Command(cmd=["echo", "hello"], output_parser=identity_parser) app = MyApp() diff --git a/src/clabe/apps/_python_script.py b/src/clabe/apps/_python_script.py index cb1143b9..af74b009 100644 --- a/src/clabe/apps/_python_script.py +++ b/src/clabe/apps/_python_script.py @@ -53,8 +53,8 @@ class PythonScriptApp(ExecutableApp, _DefaultExecutorMixin): def __init__( self, /, - script: str, - additional_arguments: str = "", + script: str | list[str], + additional_arguments: list[str] | None = None, project_directory: os.PathLike = Path("."), extra_uv_arguments: str = "", optional_toml_dependencies: Optional[list[str]] = None, @@ -70,7 +70,7 @@ def __init__( Args: script: The Python script command to be executed (e.g., "my_module.py" or "my_package run") - additional_arguments: Additional arguments to pass to the script. Defaults to empty string + additional_arguments: Additional arguments to pass to the script. Defaults to None project_directory: The directory where the project resides. Defaults to current directory extra_uv_arguments: Extra arguments to pass to the uv command (e.g., "-q" for quiet). Defaults to empty string optional_toml_dependencies: Additional TOML dependency groups to include (e.g., ["dev", "test"]). Defaults to None @@ -104,25 +104,25 @@ def __init__( ) ``` """ + script = [script] if isinstance(script, str) else script if not skip_validation: self._validate_uv() if not self._has_venv(project_directory): logger.warning("Python environment not found. Creating one...") self.create_environment(project_directory) - self._command = Command[CommandResult](cmd="", output_parser=identity_parser) - - self.command.append_arg( - [ - "uv run", - extra_uv_arguments, - self._make_uv_optional_toml_dependencies(optional_toml_dependencies or []), - self._make_uv_project_directory(project_directory), - "python" if append_python_exe else "", - script, - additional_arguments, - ] - ) + cmd_args: list[str] = ["uv", "run"] + if extra_uv_arguments: + cmd_args.extend(extra_uv_arguments.split()) + cmd_args.extend(self._make_uv_optional_toml_dependencies(optional_toml_dependencies or [])) + cmd_args.extend(self._make_uv_project_directory(project_directory)) + if append_python_exe: + cmd_args.append("python") + cmd_args.extend(script) + if additional_arguments: + cmd_args.extend(additional_arguments) + + self._command = Command[CommandResult](cmd=cmd_args, output_parser=identity_parser) @property def command(self) -> Command[CommandResult]: @@ -186,9 +186,10 @@ def create_environment( # TODO we should probably add a way to run this through our executors logger.info("Creating Python environment with uv venv at %s...", project_directory) run_kwargs = run_kwargs or {} + cmd = ["uv", "venv"] + cls._make_uv_project_directory(project_directory) try: proc = subprocess.run( - f"uv venv {cls._make_uv_project_directory(project_directory)} ", + cmd, shell=False, capture_output=True, text=True, @@ -202,30 +203,29 @@ def create_environment( return proc @staticmethod - def _make_uv_project_directory(project_directory: str | os.PathLike) -> str: + def _make_uv_project_directory(project_directory: str | os.PathLike) -> list[str]: """ Constructs the --directory argument for the uv command. Converts the project directory path to an absolute path and formats it - as a uv command-line argument. + as uv command-line arguments. Args: project_directory: The project directory path Returns: - str: The formatted --directory argument string + list[str]: The formatted --directory arguments as a list Example: ```python - arg = PythonScriptApp._make_uv_project_directory("/my/project") - # Returns: "--directory /my/project" + args = PythonScriptApp._make_uv_project_directory("/my/project") + # Returns: ["--directory", "/my/project"] ``` """ - - return f"--directory {Path(project_directory).resolve()}" + return ["--directory", str(Path(project_directory).resolve())] @staticmethod - def _make_uv_optional_toml_dependencies(optional_toml_dependencies: list[str]) -> str: + def _make_uv_optional_toml_dependencies(optional_toml_dependencies: list[str]) -> list[str]: """ Constructs the --extra arguments for the uv command based on optional TOML dependencies. @@ -236,20 +236,23 @@ def _make_uv_optional_toml_dependencies(optional_toml_dependencies: list[str]) - optional_toml_dependencies: List of optional dependency group names Returns: - str: The formatted --extra arguments string, or empty string if no dependencies + list[str]: The formatted --extra arguments as a list, or empty list if no dependencies Example: ```python args = PythonScriptApp._make_uv_optional_toml_dependencies(["dev", "test"]) - # Returns: "--extra dev --extra test" + # Returns: ["--extra", "dev", "--extra", "test"] args = PythonScriptApp._make_uv_optional_toml_dependencies([]) - # Returns: "" + # Returns: [] ``` """ if not optional_toml_dependencies: - return "" - return " ".join([f"--extra {dep}" for dep in optional_toml_dependencies]) + return [] + result: list[str] = [] + for dep in optional_toml_dependencies: + result.extend(["--extra", dep]) + return result @staticmethod def _validate_uv() -> None: diff --git a/src/clabe/apps/open_ephys.py b/src/clabe/apps/open_ephys.py index e7f98a36..b9b32165 100644 --- a/src/clabe/apps/open_ephys.py +++ b/src/clabe/apps/open_ephys.py @@ -69,7 +69,7 @@ def __init__( self.validate() self._command = Command[CommandResult]( - cmd=f'"{self.executable}" "{self.signal_chain}"', output_parser=identity_parser + cmd=[str(self.executable), str(self.signal_chain)], output_parser=identity_parser ) def validate(self): diff --git a/src/clabe/data_transfer/robocopy.py b/src/clabe/data_transfer/robocopy.py index 2567e76b..4acc3be7 100644 --- a/src/clabe/data_transfer/robocopy.py +++ b/src/clabe/data_transfer/robocopy.py @@ -2,7 +2,7 @@ import shutil from os import PathLike, makedirs from pathlib import Path -from typing import ClassVar, Dict, Optional +from typing import ClassVar, Dict, List, Optional from ..apps import ExecutableApp from ..apps._base import Command, CommandResult, identity_parser @@ -40,15 +40,15 @@ class RobocopyService(DataTransfer[RobocopySettings], _DefaultExecutorMixin, Exe A data transfer service that uses Robocopy to copy files between directories. Provides a wrapper around the Windows Robocopy utility with configurable options - for file copying, logging, and directory management. + for file copying, logging, and directory management. Supports both single + source-destination pairs and multiple mappings via a dictionary. Attributes: - command: The underlying robocopy command that will be executed + command: The robocopy command to be executed Methods: transfer: Executes the Robocopy file transfer validate: Validates the Robocopy service configuration - prompt_input: Prompts the user to confirm the file transfer """ def __init__( @@ -60,26 +60,22 @@ def __init__( Initializes the RobocopyService. Args: - source: The source directory or file to copy - settings: RobocopySettings containing destination and options + source: The source directory/file to copy, or a dict mapping sources to destinations + settings: RobocopySettings containing options Example: ```python - # Initialize with basic parameters: + # Single source-destination: settings = RobocopySettings(destination="D:/destination") service = RobocopyService("C:/source", settings) - # Initialize with logging and move operation: - settings = RobocopySettings( - destination="D:/archive/data", - log="transfer.log", - delete_src=True, - extra_args="/E /COPY:DAT /R:10" - ) - service = RobocopyService("C:/temp/data", settings) + # Multiple source-destination mappings: + service = RobocopyService({ + "C:/data1": "D:/backup1", + "C:/data2": "D:/backup2", + }, settings) ``` """ - self.source = source self._settings = settings self._command = self._build_command() @@ -91,38 +87,45 @@ def command(self) -> Command[CommandResult]: def _build_command(self) -> Command[CommandResult]: """ - Builds the robocopy command based on settings. + Builds a single command that executes all robocopy operations. - Returns: - A Command object ready for execution + For single source-destination, returns a direct robocopy command. + For multiple mappings, chains commands using `cmd /c`. - Raises: - ValueError: If source and destination mapping cannot be resolved + Returns: + A Command object ready for execution. """ - src_dst = self._solve_src_dst_mapping(self.source, self._settings.destination) - - commands = [] - for src, dst in src_dst.items(): - dst = Path(dst) - src = Path(src) + if isinstance(self.source, dict): + src_dst_pairs = [(Path(src), Path(dst)) for src, dst in self.source.items()] + else: + src_dst_pairs = [(Path(self.source), Path(self._settings.destination))] + robocopy_cmds: List[str] = [] + for src, dst in src_dst_pairs: if self._settings.force_dir: makedirs(dst, exist_ok=True) - cmd_parts = ["robocopy", f'"{src.as_posix()}"', f'"{dst.as_posix()}"', self._settings.extra_args] + cmd_parts: List[str] = ["robocopy", f"{src.as_posix()}", f"{dst.as_posix()}"] + + if self._settings.extra_args: + cmd_parts.extend(self._settings.extra_args.split()) if self._settings.log: - cmd_parts.append(f'/LOG:"{Path(dst) / self._settings.log}"') + cmd_parts.append(f"/LOG:{dst / self._settings.log}") if self._settings.delete_src: cmd_parts.append("/MOV") if self._settings.overwrite: cmd_parts.append("/IS") - commands.append(" ".join(cmd_parts)) + robocopy_cmds.append(" ".join(cmd_parts)) + + if len(robocopy_cmds) == 1: + # Single command: split back to list for direct execution + return Command(cmd=robocopy_cmds[0].split(), output_parser=identity_parser) - # TODO there may be a better way to chain with robocopy - full_command = " && ".join(commands) - return Command(cmd=full_command, output_parser=identity_parser) + # Multiple commands: use cmd /c to chain with & (robocopy is Windows-only) + chained = " & ".join(robocopy_cmds) + return Command(cmd=["cmd", "/c", chained], output_parser=identity_parser) def transfer(self) -> None: """ @@ -141,37 +144,6 @@ def transfer(self) -> None: self.run() logger.info("Robocopy transfer completed.") - @staticmethod - def _solve_src_dst_mapping( - source: PathLike | Dict[PathLike, PathLike], destination: Optional[PathLike] - ) -> Dict[PathLike, PathLike]: - """ - Resolves the mapping between source and destination paths. - - Handles both single path mappings and dictionary-based multiple mappings - to create a consistent source-to-destination mapping structure. - - Args: - source: A single source path or a dictionary mapping sources to destinations - destination: The destination path if the source is a single path - - Returns: - A dictionary mapping source paths to destination paths - - Raises: - ValueError: If the input arguments are invalid or inconsistent - """ - if isinstance(source, dict): - if destination: - raise ValueError("Destination should not be provided when source is a dictionary.") - else: - return source - else: - source = Path(source) - if not destination: - raise ValueError("Destination should be provided when source is a single path.") - return {source: Path(destination)} - def validate(self) -> bool: """ Validates whether the Robocopy command is available on the system. diff --git a/src/clabe/xml_rpc/_executor.py b/src/clabe/xml_rpc/_executor.py index 9a86c2db..f29a53ec 100644 --- a/src/clabe/xml_rpc/_executor.py +++ b/src/clabe/xml_rpc/_executor.py @@ -38,7 +38,7 @@ class XmlRpcExecutor: executor = RpcExecutor(client) # Use as synchronous executor - cmd = Command(cmd="echo hello", output_parser=identity_parser) + cmd = Command(cmd=["echo", "hello"], output_parser=identity_parser) result = executor.run(cmd) # Use as asynchronous executor @@ -80,7 +80,7 @@ def run(self, command: Command[Any]) -> CommandResult: Execute the command synchronously via RPC and return the result. Args: - command: The command to execute remotely + command: The command to execute remotely (as a list of strings) Returns: CommandResult with execution output and exit code @@ -92,7 +92,7 @@ def run(self, command: Command[Any]) -> CommandResult: Example: ```python executor = RpcExecutor(settings) - cmd = Command(cmd="python --version", output_parser=identity_parser) + cmd = Command(cmd=["python", "--version"], output_parser=identity_parser) result = executor.run(cmd) print(f"Output: {result.stdout}") ``` @@ -111,7 +111,7 @@ async def run_async(self, command: Command[Any]) -> CommandResult: Execute the command asynchronously via RPC and return the result. Args: - command: The command to execute remotely + command: The command to execute remotely (as a list of strings) Returns: CommandResult with execution output and exit code @@ -123,7 +123,7 @@ async def run_async(self, command: Command[Any]) -> CommandResult: Example: ```python executor = RpcExecutor(settings) - cmd = Command(cmd="sleep 5 && echo done", output_parser=identity_parser) + cmd = Command(cmd=["python", "-c", "print('done')"], output_parser=identity_parser) result = await executor.run_async(cmd) print(f"Output: {result.stdout}") ``` diff --git a/src/clabe/xml_rpc/_server.py b/src/clabe/xml_rpc/_server.py index 39776a81..f67cd223 100644 --- a/src/clabe/xml_rpc/_server.py +++ b/src/clabe/xml_rpc/_server.py @@ -8,11 +8,11 @@ from concurrent.futures import Future, ThreadPoolExecutor from functools import wraps from pathlib import Path -from typing import ClassVar +from typing import ClassVar, Optional from xmlrpc.server import SimpleXMLRPCServer from pydantic import Field, IPvAnyAddress, SecretStr -from pydantic_settings import CliApp +from pydantic_settings import CliApp, CliImplicitFlag from ..constants import TMP_DIR from ..services import ServiceSettings @@ -82,59 +82,80 @@ def __init__(self, settings: XmlRpcServerSettings): server.register_function(self.require_auth(self.delete_file), "delete_file") server.register_function(self.require_auth(self.delete_all_files), "delete_all_files") - logger.info(f"Authentication token: {settings.token.get_secret_value()}") - logger.info(f"XML-RPC server running on {settings.address}:{settings.port}...") - logger.info(f"File transfer directory: {settings.file_transfer_dir.resolve()}") + logger.info("Authentication token: %s", settings.token.get_secret_value()) + logger.info("XML-RPC server running on %s:%s...", settings.address, settings.port) + logger.info("File transfer directory: %s", settings.file_transfer_dir.resolve()) logger.info("Use the token above to authenticate requests") self.server = server def authenticate(self, token: str) -> bool: """Validate token and check expiry""" - return bool(token and token == self.settings.token.get_secret_value()) + is_valid = bool(token and token == self.settings.token.get_secret_value()) + logger.debug("Authentication attempt: %s", "successful" if is_valid else "failed") + return is_valid def require_auth(self, func): """Decorator to require authentication""" @wraps(func) def wrapper(token, *args, **kwargs): + logger.debug("RPC call to '%s' with args=%s...", func.__name__, args[:2] if len(args) > 2 else args) if not self.authenticate(token): + logger.warning("Authentication failed for '%s'", func.__name__) return {"error": "Invalid or expired token"} - return func(*args, **kwargs) + logger.debug("Executing '%s'", func.__name__) + result = func(*args, **kwargs) + logger.debug("'%s' completed successfully", func.__name__) + return result return wrapper def _run_command_sync(self, cmd_args): """Internal method: actually runs the subprocess""" + logger.debug("Executing command: %s", cmd_args) try: proc = subprocess.run(cmd_args, capture_output=True, text=True, check=True) + logger.debug( + "Command completed successfully. Return code: %s, stdout: %s", + proc.returncode, + proc.stdout[:200] + "..." if len(proc.stdout) > 200 else proc.stdout, + ) return {"stdout": proc.stdout, "stderr": proc.stderr, "returncode": proc.returncode} except subprocess.CalledProcessError as e: + logger.error("Command failed with return code: %s, stderr: %s", e.returncode, e.stderr) return {"stdout": e.stdout, "stderr": e.stderr, "returncode": e.returncode} except Exception as e: + logger.error("Command execution error: %s", e) return {"error": str(e)} def submit_command(self, cmd_args): """Submit a command for background execution""" job_id = str(uuid.uuid4()) + logger.debug("Submitting job %s to executor", job_id) future = self.executor.submit(self._run_command_sync, cmd_args) self.jobs[job_id] = future - logger.info(f"Submitted job {job_id}: {cmd_args}") + logger.info("Submitted job %s: %s", job_id, cmd_args) + logger.debug("Active jobs: %s", len(self.jobs)) response = JobSubmissionResponse(success=True, job_id=job_id) return response.model_dump() def get_result(self, job_id): """Fetch the result of a finished command""" + logger.debug("Fetching result for job %s", job_id) if job_id not in self.jobs: + logger.debug("Job %s not found", job_id) return JobStatusResponse( success=False, error="Invalid job_id", job_id=job_id, status=JobStatus.ERROR ).model_dump(mode="json") future = self.jobs[job_id] if not future.done(): + logger.debug("Job %s still running", job_id) return JobStatusResponse(success=True, job_id=job_id, status=JobStatus.RUNNING).model_dump(mode="json") result = future.result() + logger.debug("Job %s completed, cleaning up", job_id) del self.jobs[job_id] # cleanup finished job return JobStatusResponse(success=True, job_id=job_id, status=JobStatus.DONE, result=result).model_dump( mode="json" @@ -143,13 +164,17 @@ def get_result(self, job_id): def is_running(self, job_id): """Check if a job is still running""" if job_id not in self.jobs: + logger.debug("Job %s not found when checking status", job_id) return False - return not self.jobs[job_id].done() + is_running = not self.jobs[job_id].done() + logger.debug("Job %s running status: %s", job_id, is_running) + return is_running def list_jobs(self): """List all running jobs""" running_jobs = [jid for jid, fut in self.jobs.items() if not fut.done()] finished_jobs = [jid for jid, fut in self.jobs.items() if fut.done()] + logger.debug("Listing jobs: %s running, %s finished", len(running_jobs), len(finished_jobs)) return JobListResponse(success=True, running=running_jobs, finished=finished_jobs).model_dump(mode="json") def upload_file(self, filename: str, data_base64: str, overwrite: bool = True) -> dict: @@ -178,7 +203,9 @@ def upload_file(self, filename: str, data_base64: str, overwrite: bool = True) - try: # For now lets force simple filenames to avoid directory traversal safe_filename = Path(filename).name + logger.debug("Upload request for file: %s (safe: %s), overwrite: %s", filename, safe_filename, overwrite) if safe_filename != filename or ".." in filename: + logger.warning("Invalid filename attempted in upload: %s", filename) return FileUploadResponse( success=False, error="Invalid filename - only simple filenames allowed, no paths" ).model_dump() @@ -191,17 +218,20 @@ def upload_file(self, filename: str, data_base64: str, overwrite: bool = True) - ).model_dump() file_data = base64.b64decode(data_base64) + logger.debug("Decoded file data: %s bytes", len(file_data)) if len(file_data) > self.settings.max_file_size: + logger.warning("File too large: %s > %s", len(file_data), self.settings.max_file_size) return FileUploadResponse( success=False, error=f"File too large. Maximum size: {self.settings.max_file_size} bytes " f"({self.settings.max_file_size / (1024 * 1024):.1f} MB)", ).model_dump() + logger.debug("Writing file to: %s", file_path) file_path.write_bytes(file_data) - logger.info(f"File uploaded: {safe_filename} ({len(file_data)} bytes)") + logger.info("File uploaded: %s (%s bytes)", safe_filename, len(file_data)) return FileUploadResponse( success=True, filename=safe_filename, @@ -211,7 +241,7 @@ def upload_file(self, filename: str, data_base64: str, overwrite: bool = True) - ).model_dump() except Exception as e: - logger.error(f"Error uploading file: {e}") + logger.error("Error uploading file: %s", e) return FileUploadResponse(success=False, error=str(e)).model_dump() def download_file(self, filename: str) -> dict: @@ -236,7 +266,9 @@ def download_file(self, filename: str) -> dict: """ try: safe_filename = Path(filename).name + logger.debug("Download request for file: %s (safe: %s)", filename, safe_filename) if safe_filename != filename or ".." in filename: + logger.warning("Invalid filename attempted in download: %s", filename) response = FileDownloadResponse( success=False, error="Invalid filename - only simple filenames allowed, no paths", @@ -272,6 +304,7 @@ def download_file(self, filename: str) -> dict: ) return response.model_dump(mode="json") + logger.debug("Reading file from: %s", file_path) file_data = file_path.read_bytes() # Base64 encode the data for Base64Bytes field @@ -279,7 +312,7 @@ def download_file(self, filename: str) -> dict: base64_encoded_data = base64.b64encode(file_data) - logger.info(f"File downloaded: {safe_filename} ({len(file_data)} bytes)") + logger.info("File downloaded: %s (%s bytes)", safe_filename, len(file_data)) response = FileDownloadResponse( success=True, error=None, @@ -290,7 +323,7 @@ def download_file(self, filename: str) -> dict: return response.model_dump(mode="json") except Exception as e: - logger.error(f"Error downloading file: {e}") + logger.error("Error downloading file: %s", e) response = FileDownloadResponse(success=False, error=str(e), filename=None, size=None, data=None) return response.model_dump(mode="json") @@ -310,6 +343,7 @@ def list_files(self) -> dict: ``` """ try: + logger.debug("Listing files in: %s", self.settings.file_transfer_dir) file_infos = [] for file_path in self.settings.file_transfer_dir.iterdir(): if file_path.is_file(): @@ -323,11 +357,12 @@ def list_files(self) -> dict: file_infos.append(file_info) file_infos.sort(key=lambda x: x.name) + logger.debug("Found %s files", len(file_infos)) response = FileListResponse(success=True, error=None, files=file_infos, count=len(file_infos)) return response.model_dump() except Exception as e: - logger.error(f"Error listing files: {e}") + logger.error("Error listing files: %s", e) response = FileListResponse(success=False, error=str(e), files=[], count=0) return response.model_dump() @@ -349,7 +384,9 @@ def delete_file(self, filename: str) -> dict: """ try: safe_filename = Path(filename).name + logger.debug("Delete request for file: %s (safe: %s)", filename, safe_filename) if safe_filename != filename or ".." in filename: + logger.warning("Invalid filename attempted in delete: %s", filename) response = FileDeleteResponse( success=False, error="Invalid filename - only simple filenames allowed, no paths", filename=None ) @@ -366,12 +403,12 @@ def delete_file(self, filename: str) -> dict: return response.model_dump() file_path.unlink() - logger.info(f"File deleted: {safe_filename}") + logger.info("File deleted: %s", safe_filename) response = FileDeleteResponse(success=True, error=None, filename=safe_filename) return response.model_dump() except Exception as e: - logger.error(f"Error deleting file: {e}") + logger.error("Error deleting file: %s", e) response = FileDeleteResponse(success=False, error=str(e), filename=None) return response.model_dump() @@ -390,6 +427,7 @@ def delete_all_files(self) -> dict: ``` """ try: + logger.debug("Deleting all files from: %s", self.settings.file_transfer_dir) deleted_files = [] deleted_count = 0 @@ -400,9 +438,9 @@ def delete_all_files(self) -> dict: deleted_files.append(file_path.name) deleted_count += 1 except Exception as e: - logger.warning(f"Failed to delete {file_path.name}: {e}") + logger.error("Failed to delete %s: %s", file_path.name, e) - logger.info(f"Deleted all files: {deleted_count} file(s) removed") + logger.info("Deleted all files: %s file(s) removed", deleted_count) response = FileBulkDeleteResponse( success=True, error=None, @@ -412,7 +450,7 @@ def delete_all_files(self) -> dict: return response.model_dump() except Exception as e: - logger.error(f"Error deleting all files: {e}") + logger.error("Error deleting all files: %s", e) response = FileBulkDeleteResponse(success=False, error=str(e), deleted_count=0, deleted_files=[]) return response.model_dump() @@ -420,8 +458,32 @@ def delete_all_files(self) -> dict: class _XmlRpcServerStartCli(XmlRpcServerSettings): """CLI application wrapper for the RPC server.""" + debug: CliImplicitFlag[bool] = Field(default=False, description="Enable debug logging") + dump: Optional[Path] = Field(default=None, description="Path to dump logs to file") + def cli_cmd(self): """Start the RPC server and run it until interrupted.""" + log_level = logging.DEBUG if self.debug else logging.INFO + log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + handlers = [logging.StreamHandler()] + + if self.dump: + self.dump.parent.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(self.dump, mode="w") + file_handler.setFormatter(logging.Formatter(log_format)) + handlers.append(file_handler) + logging.info("Logging dumped to file: %s", self.dump) + + module_logger = logging.getLogger("clabe.xml_rpc") + module_logger.setLevel(log_level) + for handler in handlers: + handler.setFormatter(logging.Formatter(log_format)) + module_logger.addHandler(handler) + + if self.debug: + logger.debug("Debug logging enabled") + server = XmlRpcServer(settings=self) try: server.server.serve_forever() @@ -430,4 +492,4 @@ def cli_cmd(self): if __name__ == "__main__": - CliApp().run(_XmlRpcServerStartCli) + CliApp().run(_XmlRpcServerStartCli, cli_args=["--token", "42", "--debug"]) diff --git a/tests/apps/test_app.py b/tests/apps/test_app.py index 67d9396f..96865fac 100644 --- a/tests/apps/test_app.py +++ b/tests/apps/test_app.py @@ -27,7 +27,7 @@ def simple_command() -> Command[CommandResult]: """A simple command that echoes text.""" return Command[CommandResult]( - cmd="python -c \"print('hello')\"", + cmd=["python", "-c", "print('hello')"], output_parser=identity_parser, ) @@ -36,7 +36,7 @@ def simple_command() -> Command[CommandResult]: def failing_command() -> Command[CommandResult]: """A command that fails.""" return Command[CommandResult]( - cmd='python -c "import sys; sys.exit(1)"', + cmd=["python", "-c", "import sys; sys.exit(1)"], output_parser=identity_parser, ) @@ -58,7 +58,7 @@ class MockExecutor(Executor): def __init__(self, return_value: CommandResult): self.return_value = return_value - self.executed_commands: list[str] = [] + self.executed_commands: list[list[str]] = [] def run(self, command: Command) -> CommandResult: self.executed_commands.append(command.cmd) @@ -70,7 +70,7 @@ class MockAsyncExecutor(AsyncExecutor): def __init__(self, return_value: CommandResult): self.return_value = return_value - self.executed_commands: list[str] = [] + self.executed_commands: list[list[str]] = [] async def run_async(self, command: Command) -> CommandResult: self.executed_commands.append(command.cmd) @@ -112,38 +112,38 @@ class TestCommand: """Tests for Command class.""" def test_command_initialization(self): - """Test basic command initialization.""" - cmd = Command[str](cmd="echo hello", output_parser=lambda r: r.stdout or "") - assert cmd.cmd == "echo hello" + """Test basic command initialization with list.""" + cmd = Command[str](cmd=["echo", "hello"], output_parser=lambda r: r.stdout or "") + assert cmd.cmd == ["echo", "hello"] def test_command_append_arg_single_string(self): """Test appending a single argument.""" - cmd = Command[CommandResult](cmd="echo", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo"], output_parser=identity_parser) cmd.append_arg("hello") - assert cmd.cmd == "echo hello" + assert cmd.cmd == ["echo", "hello"] def test_command_append_arg_list(self): """Test appending multiple arguments as a list.""" - cmd = Command[CommandResult](cmd="echo", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo"], output_parser=identity_parser) cmd.append_arg(["hello", "world"]) - assert cmd.cmd == "echo hello world" + assert cmd.cmd == ["echo", "hello", "world"] def test_command_append_arg_filters_empty_strings(self): """Test that empty strings are filtered out when appending args.""" - cmd = Command[CommandResult](cmd="echo", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo"], output_parser=identity_parser) cmd.append_arg(["hello", "", "world"]) - assert cmd.cmd == "echo hello world" + assert cmd.cmd == ["echo", "hello", "world"] def test_command_append_arg_chaining(self): """Test that append_arg returns self for chaining.""" - cmd = Command[CommandResult](cmd="echo", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo"], output_parser=identity_parser) result = cmd.append_arg("hello").append_arg("world") assert result is cmd - assert cmd.cmd == "echo hello world" + assert cmd.cmd == ["echo", "hello", "world"] def test_command_result_property_before_execution_raises(self): """Test that accessing result before execution raises RuntimeError.""" - cmd = Command[CommandResult](cmd="echo hello", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo", "hello"], output_parser=identity_parser) with pytest.raises(RuntimeError, match="Command has not been executed yet"): _ = cmd.result @@ -152,11 +152,11 @@ def test_command_execute_with_mock_executor(self): expected_result = CommandResult(stdout="output", stderr="", exit_code=0) executor = MockExecutor(return_value=expected_result) - cmd = Command[CommandResult](cmd="echo hello", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo", "hello"], output_parser=identity_parser) result = cmd.execute(executor) assert result == expected_result - assert "echo hello" in executor.executed_commands + assert ["echo", "hello"] in executor.executed_commands assert cmd.result == expected_result @pytest.mark.asyncio @@ -165,11 +165,11 @@ async def test_command_execute_async_with_mock_executor(self): expected_result = CommandResult(stdout="output", stderr="", exit_code=0) executor = MockAsyncExecutor(return_value=expected_result) - cmd = Command[CommandResult](cmd="echo hello", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo", "hello"], output_parser=identity_parser) result = await cmd.execute_async(executor) assert result == expected_result - assert "echo hello" in executor.executed_commands + assert ["echo", "hello"] in executor.executed_commands assert cmd.result == expected_result def test_command_custom_output_parser(self): @@ -178,7 +178,7 @@ def test_command_custom_output_parser(self): def parse_int(result: CommandResult) -> int: return int(result.stdout.strip()) if result.stdout else 0 - cmd = Command[int](cmd='python -c "print(42)"', output_parser=parse_int) + cmd = Command[int](cmd=["python", "-c", "print(42)"], output_parser=parse_int) executor = MockExecutor(CommandResult(stdout="42\n", stderr="", exit_code=0)) result = cmd.execute(executor) @@ -195,7 +195,7 @@ class TestLocalExecutor: def test_local_executor_runs_simple_command(self, local_executor: LocalExecutor): """Test that LocalExecutor can run a simple command.""" - cmd = Command[CommandResult](cmd="python -c \"print('test')\"", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["python", "-c", "print('test')"], output_parser=identity_parser) result = cmd.execute(local_executor) assert result.ok is True @@ -204,7 +204,7 @@ def test_local_executor_runs_simple_command(self, local_executor: LocalExecutor) def test_local_executor_captures_stderr(self, local_executor: LocalExecutor): """Test that LocalExecutor captures stderr.""" cmd = Command[CommandResult]( - cmd="python -c \"import sys; sys.stderr.write('error')\"", output_parser=identity_parser + cmd=["python", "-c", "import sys; sys.stderr.write('error')"], output_parser=identity_parser ) result = cmd.execute(local_executor) @@ -220,7 +220,7 @@ def test_local_executor_with_custom_cwd(self, tmp_path: Path): """Test LocalExecutor with a custom working directory.""" executor = LocalExecutor(cwd=tmp_path) cmd = Command[CommandResult]( - cmd='python -c "import os; print(os.getcwd())"', + cmd=["python", "-c", "import os; print(os.getcwd())"], output_parser=identity_parser, ) result = cmd.execute(executor) @@ -236,7 +236,7 @@ class TestAsyncLocalExecutor: @pytest.mark.asyncio async def test_async_executor_runs_simple_command(self, async_local_executor: AsyncLocalExecutor): """Test that AsyncLocalExecutor can run a simple command.""" - cmd = Command[CommandResult](cmd="python -c \"print('async test')\"", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["python", "-c", "print('async test')"], output_parser=identity_parser) result = await cmd.execute_async(async_local_executor) assert result.ok is True @@ -253,8 +253,8 @@ async def test_async_executor_handles_failing_command( @pytest.mark.asyncio async def test_async_executor_concurrent_execution(self, async_local_executor: AsyncLocalExecutor): """Test running multiple commands concurrently.""" - cmd1 = Command[CommandResult](cmd="python -c \"print('cmd1')\"", output_parser=identity_parser) - cmd2 = Command[CommandResult](cmd="python -c \"print('cmd2')\"", output_parser=identity_parser) + cmd1 = Command[CommandResult](cmd=["python", "-c", "print('cmd1')"], output_parser=identity_parser) + cmd2 = Command[CommandResult](cmd=["python", "-c", "print('cmd2')"], output_parser=identity_parser) results = await asyncio.gather( cmd1.execute_async(async_local_executor), cmd2.execute_async(async_local_executor) @@ -300,8 +300,11 @@ def test_bonsai_app_builds_command_correctly(self, temp_bonsai_files): is_start_flag=True, ) cmd = app.command.cmd - assert str(temp_bonsai_files["exe"]) in cmd - assert str(temp_bonsai_files["workflow"]) in cmd + # Command is now a list + assert isinstance(cmd, list) + cmd_str = " ".join(cmd) + assert str(temp_bonsai_files["exe"]) in cmd_str + assert str(temp_bonsai_files["workflow"]) in cmd_str assert "--start" in cmd def test_bonsai_app_no_editor_mode(self, temp_bonsai_files): @@ -312,6 +315,7 @@ def test_bonsai_app_no_editor_mode(self, temp_bonsai_files): is_editor_mode=False, ) cmd = app.command.cmd + assert isinstance(cmd, list) assert "--no-editor" in cmd assert "--start" not in cmd @@ -323,8 +327,10 @@ def test_bonsai_app_with_additional_properties(self, temp_bonsai_files): additional_externalized_properties={"param1": "value1", "param2": "value2"}, ) cmd = app.command.cmd - assert '-p:"param1"="value1"' in cmd - assert '-p:"param2"="value2"' in cmd + assert isinstance(cmd, list) + # Properties are now in format -p:param1=value1 (without quotes) + assert "-p:param1=value1" in cmd + assert "-p:param2=value2" in cmd def test_bonsai_app_validates_executable_exists(self, tmp_path: Path): """Test that BonsaiApp validation fails if executable doesn't exist.""" @@ -398,8 +404,12 @@ def test_python_script_app_initialization(self, tmp_path: Path): project_directory=tmp_path, ) - assert "uv run" in app.command.cmd - assert "test_script.py" in app.command.cmd + cmd = app.command.cmd + # Command is now a list + assert isinstance(cmd, list) + assert "uv" in cmd + assert "run" in cmd + assert "test_script.py" in cmd def test_python_script_app_with_additional_arguments(self, tmp_path: Path): """Test PythonScriptApp with additional arguments.""" @@ -408,11 +418,12 @@ def test_python_script_app_with_additional_arguments(self, tmp_path: Path): app = PythonScriptApp( script="test.py", - additional_arguments="--verbose --debug", + additional_arguments=["--verbose", "--debug"], project_directory=tmp_path, ) cmd = app.command.cmd + assert isinstance(cmd, list) assert "--verbose" in cmd assert "--debug" in cmd @@ -428,7 +439,11 @@ def test_python_script_app_with_optional_dependencies(self, tmp_path: Path): ) cmd = app.command.cmd - assert "--extra dev" in cmd or "--with dev" in cmd or "dev" in cmd + assert isinstance(cmd, list) + # --extra and the dependency name are now separate list items + assert "--extra" in cmd + assert "dev" in cmd + assert "test" in cmd def test_python_script_app_appends_python_exe(self, tmp_path: Path): """Test PythonScriptApp with append_python_exe=True.""" @@ -442,6 +457,7 @@ def test_python_script_app_appends_python_exe(self, tmp_path: Path): ) cmd = app.command.cmd + assert isinstance(cmd, list) assert "python" in cmd assert "test.py" in cmd @@ -463,7 +479,9 @@ def test_python_script_app_skip_validation(self, tmp_path: Path): skip_validation=True, ) - assert "test.py" in app.command.cmd + cmd = app.command.cmd + assert isinstance(cmd, list) + assert "test.py" in cmd def test_python_script_app_can_be_executed_with_mock_executor(self, tmp_path: Path): """Test that PythonScriptApp can be executed with a mock executor.""" @@ -493,7 +511,7 @@ class TestIntegration: def test_same_command_different_executors(self, tmp_path: Path): """Test that the same command can be run with different executors.""" - cmd = Command[CommandResult](cmd="python -c \"print('hello')\"", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["python", "-c", "print('hello')"], output_parser=identity_parser) # Run with first executor executor1 = MockExecutor(CommandResult(stdout="output1", stderr="", exit_code=0)) @@ -503,8 +521,8 @@ def test_same_command_different_executors(self, tmp_path: Path): executor2 = MockExecutor(CommandResult(stdout="output2", stderr="", exit_code=0)) cmd._set_result(executor2.run(cmd), override=True) - assert "hello" in executor1.executed_commands[0] - assert "hello" in executor2.executed_commands[0] + assert "hello" in " ".join(executor1.executed_commands[0]) + assert "hello" in " ".join(executor2.executed_commands[0]) def test_app_with_custom_executor(self, tmp_path: Path): """Test using an app with a custom executor instead of the default.""" @@ -526,8 +544,8 @@ def test_app_with_custom_executor(self, tmp_path: Path): async def test_async_and_sync_executors_with_same_command_type(self): """Test that both sync and async executors can work with commands.""" # Note: We use different command instances since they store results - sync_cmd = Command[CommandResult](cmd="python -c \"print('sync')\"", output_parser=identity_parser) - async_cmd = Command[CommandResult](cmd="python -c \"print('async')\"", output_parser=identity_parser) + sync_cmd = Command[CommandResult](cmd=["python", "-c", "print('sync')"], output_parser=identity_parser) + async_cmd = Command[CommandResult](cmd=["python", "-c", "print('async')"], output_parser=identity_parser) sync_executor = MockExecutor(CommandResult(stdout="sync output", stderr="", exit_code=0)) async_executor = MockAsyncExecutor(CommandResult(stdout="async output", stderr="", exit_code=0)) @@ -547,15 +565,15 @@ async def test_async_and_sync_executors_with_same_command_type(self): class TestEdgeCases: """Tests for edge cases and error handling.""" - def test_command_with_empty_string(self): - """Test command with empty cmd string.""" - cmd = Command[CommandResult](cmd="", output_parser=identity_parser) - assert cmd.cmd == "" + def test_command_with_empty_list(self): + """Test command with empty cmd list.""" + cmd = Command[CommandResult](cmd=[], output_parser=identity_parser) + assert cmd.cmd == [] def test_command_result_multiple_override_warning(self): """Test that overriding result logs a warning.""" - cmd = Command[CommandResult](cmd="echo test", output_parser=identity_parser) + cmd = Command[CommandResult](cmd=["echo", "test"], output_parser=identity_parser) result1 = CommandResult(stdout="first", stderr="", exit_code=0) result2 = CommandResult(stdout="second", stderr="", exit_code=0) @@ -585,7 +603,9 @@ def test_python_script_app_with_empty_additional_arguments(self, tmp_path: Path) ) # Command should still be valid - assert "test.py" in app.command.cmd + cmd = app.command.cmd + assert isinstance(cmd, list) + assert "test.py" in cmd @pytest.fixture diff --git a/tests/apps/test_curriculum.py b/tests/apps/test_curriculum.py index 6e881e7a..01f3e82f 100644 --- a/tests/apps/test_curriculum.py +++ b/tests/apps/test_curriculum.py @@ -15,7 +15,7 @@ def curriculum_app() -> CurriculumApp: return CurriculumApp( settings=CurriculumSettings( - script="curriculum run", + script=["curriculum", "run"], input_trainer_state=Path("MockPath"), data_directory="Demo", project_directory=TESTS_ASSETS / "Aind.Behavior.VrForaging.Curricula", diff --git a/tests/data_transfer/test_data_transfer.py b/tests/data_transfer/test_data_transfer.py index 6b36a630..a5727f12 100644 --- a/tests/data_transfer/test_data_transfer.py +++ b/tests/data_transfer/test_data_transfer.py @@ -1,5 +1,7 @@ import os +import shutil import subprocess +import sys import tempfile from datetime import datetime, time from pathlib import Path @@ -19,6 +21,9 @@ from clabe.data_transfer.robocopy import RobocopyService, RobocopySettings from tests import TESTS_ASSETS +_HAS_ROBOCOPY = shutil.which("robocopy") is not None +_IS_WINDOWS = sys.platform == "win32" + @pytest.fixture def source(): @@ -458,6 +463,23 @@ def test_yaml_dump_and_write_read_yaml( assert loaded.get("name") == manifest.name +@pytest.fixture +def robocopy_temp_dirs(tmp_path): + """Create temporary source and destination directories for robocopy tests.""" + source_dir = tmp_path / "source" + dest_dir = tmp_path / "destination" + source_dir.mkdir() + + (source_dir / "file1.txt").write_text("content1") + (source_dir / "file2.txt").write_text("content2") + subdir = source_dir / "subdir" + subdir.mkdir() + (subdir / "file3.txt").write_text("content3") + # Cleanup handled by tmp_path fixture + + yield source_dir, dest_dir + + @pytest.fixture def robocopy_settings(): return RobocopySettings( @@ -488,28 +510,171 @@ def test_initialization(self, robocopy_service, source, robocopy_settings): assert robocopy_service._settings.overwrite assert not robocopy_service._settings.force_dir - def test_transfer(self, robocopy_service): + def test_transfer_mocked(self, robocopy_service): with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock(stdout="output", stderr="", returncode=0) robocopy_service.transfer() mock_run.assert_called_once() - def test_run(self, robocopy_service): + def test_run_mocked(self, robocopy_service): with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock(stdout="output", stderr="", returncode=0) result = robocopy_service.run() assert result.ok is True mock_run.assert_called_once() - def test_solve_src_dst_mapping_single_path(self, robocopy_service, source, robocopy_settings): - result = robocopy_service._solve_src_dst_mapping(source, robocopy_settings.destination) - assert result == {Path(source): Path(robocopy_settings.destination)} + def test_command_single_source(self, robocopy_temp_dirs): + """Test command building for single source-destination.""" + source_dir, dest_dir = robocopy_temp_dirs + settings = RobocopySettings(destination=dest_dir, force_dir=False, extra_args="/E") + service = RobocopyService(source=source_dir, settings=settings) + + cmd = service.command.cmd + assert cmd[0] == "robocopy" + assert source_dir.as_posix() in cmd[1] + assert dest_dir.as_posix() in cmd[2] + assert "/E" in cmd + + def test_command_dict_multiple_sources(self, tmp_path): + """Test command building for dict with multiple source-destination pairs.""" + src1 = tmp_path / "src1" + src2 = tmp_path / "src2" + dst1 = tmp_path / "dst1" + dst2 = tmp_path / "dst2" + src1.mkdir() + src2.mkdir() + + settings = RobocopySettings(destination=dst1, force_dir=False, extra_args="/E") + service = RobocopyService(source={src1: dst1, src2: dst2}, settings=settings) + + cmd = service.command.cmd + # Multiple commands should use cmd /c with & to chain + assert cmd[0] == "cmd" + assert cmd[1] == "/c" + assert "&" in cmd[2] + + def test_validate_without_robocopy(self, robocopy_service): + """Test validate method behavior.""" + with patch("clabe.data_transfer.robocopy._HAS_ROBOCOPY", False): + # Reload the service to use the patched value + service = RobocopyService(source=robocopy_service.source, settings=robocopy_service._settings) + # The validate method checks the module-level _HAS_ROBOCOPY + from clabe.data_transfer import robocopy + + original = robocopy._HAS_ROBOCOPY + robocopy._HAS_ROBOCOPY = False + try: + assert not service.validate() + finally: + robocopy._HAS_ROBOCOPY = original + + @pytest.mark.skipif(not _IS_WINDOWS or not _HAS_ROBOCOPY, reason="Requires Windows with robocopy") + def test_transfer_actual_single_source(self, robocopy_temp_dirs): + """Test actual robocopy execution with single source-destination.""" + from clabe.apps import CommandError + + source_dir, dest_dir = robocopy_temp_dirs + settings = RobocopySettings( + destination=dest_dir, + extra_args="/E /DCOPY:DAT /R:1 /W:1", + force_dir=True, + ) + service = RobocopyService(source=source_dir, settings=settings) + + # Robocopy exit codes 0-7 are success, but CommandError is raised for non-zero + try: + service.transfer() + except CommandError as e: + # Exit codes 1-7 are actually success for robocopy + assert e.exit_code < 8, f"Robocopy failed with exit code {e.exit_code}" + + # Verify files were copied + assert (dest_dir / "file1.txt").exists() + assert (dest_dir / "file1.txt").read_text() == "content1" + assert (dest_dir / "file2.txt").exists() + assert (dest_dir / "subdir" / "file3.txt").exists() + + @pytest.mark.skipif(not _IS_WINDOWS or not _HAS_ROBOCOPY, reason="Requires Windows with robocopy") + def test_transfer_actual_dict_sources(self, tmp_path): + """Test actual robocopy execution with dict multiple source-destination pairs.""" + from clabe.apps import CommandError + + # Create two source directories + src1 = tmp_path / "src1" + src2 = tmp_path / "src2" + dst1 = tmp_path / "dst1" + dst2 = tmp_path / "dst2" + src1.mkdir() + src2.mkdir() + + # Create files in each source + (src1 / "from_src1.txt").write_text("source1_content") + (src2 / "from_src2.txt").write_text("source2_content") + + settings = RobocopySettings( + destination=Path("not used_in_dict_case"), + extra_args="/E /DCOPY:DAT /R:1 /W:1", + force_dir=True, + ) + service = RobocopyService(source={src1: dst1, src2: dst2}, settings=settings) + + try: + service.transfer() + except CommandError as e: + assert e.exit_code < 8, f"Robocopy failed with exit code {e.exit_code}" + + # Verify files were copied to respective destinations + assert (dst1 / "from_src1.txt").exists() + assert (dst1 / "from_src1.txt").read_text() == "source1_content" + assert (dst2 / "from_src2.txt").exists() + assert (dst2 / "from_src2.txt").read_text() == "source2_content" + + @pytest.mark.skipif(not _IS_WINDOWS or not _HAS_ROBOCOPY, reason="Requires Windows with robocopy") + def test_transfer_with_delete_src(self, robocopy_temp_dirs): + """Test robocopy with delete_src option (move instead of copy).""" + from clabe.apps import CommandError + + source_dir, dest_dir = robocopy_temp_dirs + settings = RobocopySettings( + destination=dest_dir, + extra_args="/E /R:1 /W:1", + delete_src=True, + force_dir=True, + ) + service = RobocopyService(source=source_dir, settings=settings) - def test_solve_src_dst_mapping_dict(self, robocopy_service, source, robocopy_settings): - source_dict = {source: robocopy_settings.destination} - result = robocopy_service._solve_src_dst_mapping(source_dict, None) - assert result == source_dict + try: + service.transfer() + except CommandError as e: + assert e.exit_code < 8, f"Robocopy failed with exit code {e.exit_code}" - def test_solve_src_dst_mapping_invalid(self, robocopy_service, source): - with pytest.raises(ValueError): - robocopy_service._solve_src_dst_mapping(source, None) + # Files should be moved (deleted from source after copy) + assert (dest_dir / "file1.txt").exists() + assert not (source_dir / "file1.txt").exists() + + @pytest.mark.skipif(not _IS_WINDOWS or not _HAS_ROBOCOPY, reason="Requires Windows with robocopy") + def test_transfer_with_overwrite(self, robocopy_temp_dirs): + """Test robocopy with overwrite option.""" + from clabe.apps import CommandError + + source_dir, dest_dir = robocopy_temp_dirs + dest_dir.mkdir(exist_ok=True) + + # Create existing file in destination with different content + (dest_dir / "file1.txt").write_text("old_content") + + settings = RobocopySettings( + destination=dest_dir, + extra_args="/E /R:1 /W:1", + overwrite=True, + force_dir=True, + ) + service = RobocopyService(source=source_dir, settings=settings) + + try: + service.transfer() + except CommandError as e: + assert e.exit_code < 8, f"Robocopy failed with exit code {e.exit_code}" + + # File should be overwritten with new content + assert (dest_dir / "file1.txt").read_text() == "content1"