22
33from __future__ import annotations
44
5- from typing import TYPE_CHECKING
5+ from typing import TYPE_CHECKING , Protocol
66from unittest .mock import patch
77from contextlib import contextmanager
88import copy
9+ import fnmatch
910
1011from e3 .os .process import Run , to_cmd_lines
1112
@@ -49,6 +50,57 @@ def mock_run(config: MockRunConfig | None = None) -> Iterator[MockRun]:
4950 yield run
5051
5152
53+ class ArgumentChecker (Protocol ):
54+ """Argument checker."""
55+
56+ def check (self , arg : str ) -> bool :
57+ """Check an argument.
58+
59+ :param arg: the argument
60+ :return: if the argument is valid
61+ """
62+ ...
63+
64+ def __repr__ (self ) -> str :
65+ """Return a textual representation of the expected argument."""
66+ ...
67+
68+
69+ class GlobChecker (ArgumentChecker ):
70+ """Check an argument against a glob."""
71+
72+ def __init__ (self , pattern : str ) -> None :
73+ """Initialize GlobChecker.
74+
75+ :param pattern: the glob pattern
76+ """
77+ self .pattern = pattern
78+
79+ def check (self , arg : str ) -> bool :
80+ """See ArgumentChecker."""
81+ return fnmatch .fnmatch (arg , self .pattern )
82+
83+ def __repr__ (self ) -> str :
84+ """See ArgumentChecker."""
85+ return self .pattern .__repr__ ()
86+
87+
88+ class SideEffect (Protocol ):
89+ """Function to be called when a mocked command is called."""
90+
91+ def __call__ (
92+ self , result : CommandResult , cmd : list [str ], * args : Any , ** kwargs : Any
93+ ) -> None :
94+ """Run when the mocked command is called.
95+
96+ :param result: the mocked command
97+ :param cmd: actual arguments of the command
98+ :param args: additional arguments for Run
99+ :param kwargs: additional keyword arguments for Run
100+ """
101+ ...
102+
103+
52104class CommandResult :
53105 """Result of a command.
54106
@@ -58,22 +110,25 @@ class CommandResult:
58110
59111 def __init__ (
60112 self ,
61- cmd : list [str ],
113+ cmd : list [str | ArgumentChecker ],
62114 status : int | None = None ,
63115 raw_out : bytes = b"" ,
64116 raw_err : bytes = b"" ,
117+ side_effect : SideEffect | None = None ,
65118 ) -> None :
66119 """Initialize CommandResult.
67120
68121 :param cmd: expected arguments of the command
69122 :param status: status code
70123 :param raw_out: raw output log
71124 :param raw_err: raw error log
125+ :param side_effect: a function to be called when the command is called
72126 """
73127 self .cmd = cmd
74128 self .status = status if status is not None else 0
75129 self .raw_out = raw_out
76130 self .raw_err = raw_err
131+ self .side_effect = side_effect
77132
78133 def check (self , cmd : list [str ]) -> None :
79134 """Check that cmd matches the expected arguments.
@@ -86,10 +141,16 @@ def check(self, cmd: list[str]) -> None:
86141 )
87142
88143 for i , arg in enumerate (cmd ):
89- if arg != self .cmd [i ] and self .cmd [i ] != "*" :
90- raise UnexpectedCommandError (
91- f"unexpected arguments { cmd } , expected { self .cmd } "
92- )
144+ checker = self .cmd [i ]
145+ if isinstance (checker , str ):
146+ if arg == checker or checker == "*" :
147+ continue
148+ elif checker .check (arg ):
149+ continue
150+
151+ raise UnexpectedCommandError (
152+ f"unexpected arguments { cmd } , expected { self .cmd } "
153+ )
93154
94155 def __call__ (self , cmd : list [str ], * args : Any , ** kwargs : Any ) -> None :
95156 """Allow to run code to emulate the command.
@@ -101,7 +162,8 @@ def __call__(self, cmd: list[str], *args: Any, **kwargs: Any) -> None:
101162 :param args: additional arguments for Run
102163 :param kwargs: additional keyword arguments for Run
103164 """
104- pass
165+ if self .side_effect :
166+ self .side_effect (self , cmd , * args , ** kwargs )
105167
106168
107169class MockRun (Run ):
0 commit comments