@@ -138,6 +138,32 @@ def __eq__(self, other: ty.Any) -> bool:
138138OutputsType = ty .TypeVar ("OutputType" , bound = TaskOutputs )
139139
140140
141+ def donothing (* args : ty .Any , ** kwargs : ty .Any ) -> None :
142+ return None
143+
144+
145+ @attrs .define (kw_only = True )
146+ class TaskHooks :
147+ """Callable task hooks."""
148+
149+ pre_run_task : ty .Callable = attrs .field (
150+ default = donothing , converter = default_if_none (donothing )
151+ )
152+ post_run_task : ty .Callable = attrs .field (
153+ default = donothing , converter = default_if_none (donothing )
154+ )
155+ pre_run : ty .Callable = attrs .field (
156+ default = donothing , converter = default_if_none (donothing )
157+ )
158+ post_run : ty .Callable = attrs .field (
159+ default = donothing , converter = default_if_none (donothing )
160+ )
161+
162+ def reset (self ):
163+ for val in ["pre_run_task" , "post_run_task" , "pre_run" , "post_run" ]:
164+ setattr (self , val , donothing )
165+
166+
141167@attrs .define (kw_only = True , auto_attribs = False , eq = False )
142168class TaskDef (ty .Generic [OutputsType ]):
143169 """Base class for all task definitions"""
@@ -161,10 +187,7 @@ def __call__(
161187 messengers : ty .Iterable [Messenger ] | None = None ,
162188 messenger_args : dict [str , ty .Any ] | None = None ,
163189 name : str | None = None ,
164- pre_run : ty .Callable ["Task" , None ] | None = None ,
165- post_run : ty .Callable ["Task" , None ] | None = None ,
166- pre_run_task : ty .Callable ["Task" , None ] | None = None ,
167- post_run_task : ty .Callable ["Task" , None ] | None = None ,
190+ hooks : TaskHooks | None = None ,
168191 ** kwargs : ty .Any ,
169192 ) -> OutputsType :
170193 """Create a task from this definition and execute it to produce a result.
@@ -220,10 +243,7 @@ def __call__(
220243 result = sub (
221244 self ,
222245 name = name ,
223- pre_run = pre_run ,
224- post_run = post_run ,
225- pre_run_task = pre_run_task ,
226- post_run_task = post_run_task ,
246+ hooks = hooks ,
227247 )
228248 except TypeError as e :
229249 # Catch any inadvertent passing of task definition parameters to the
@@ -1254,32 +1274,6 @@ def _generated_output_names(self, stdout: str, stderr: str):
12541274 DEFAULT_COPY_COLLATION = FileSet .CopyCollation .adjacent
12551275
12561276
1257- def donothing (* args : ty .Any , ** kwargs : ty .Any ) -> None :
1258- return None
1259-
1260-
1261- @attrs .define (kw_only = True )
1262- class TaskHook :
1263- """Callable task hooks."""
1264-
1265- pre_run_task : ty .Callable = attrs .field (
1266- default = donothing , converter = default_if_none (donothing )
1267- )
1268- post_run_task : ty .Callable = attrs .field (
1269- default = donothing , converter = default_if_none (donothing )
1270- )
1271- pre_run : ty .Callable = attrs .field (
1272- default = donothing , converter = default_if_none (donothing )
1273- )
1274- post_run : ty .Callable = attrs .field (
1275- default = donothing , converter = default_if_none (donothing )
1276- )
1277-
1278- def reset (self ):
1279- for val in ["pre_run_task" , "post_run_task" , "pre_run" , "post_run" ]:
1280- setattr (self , val , donothing )
1281-
1282-
12831277def split_cmd (cmd : str | None ):
12841278 """Splits a shell command line into separate arguments respecting quotes
12851279
0 commit comments