44
55import threading
66import traceback
7- from functools import wraps
87from typing import Any , Callable , List
98
109
@@ -27,25 +26,12 @@ def __init__(
2726 self .name , self .deps , self .func = name , deps , func
2827
2928
30- def op (name : str , deps = None ):
31- deps = deps or []
32-
33- def decorator (func ):
34- @wraps (func )
35- def _wrapper (* args , ** kwargs ):
36- return func (* args , ** kwargs )
37-
38- _wrapper .op_node = OpNode (name , deps , lambda self , ctx : func (self , ** ctx ))
39- return _wrapper
40-
41- return decorator
42-
43-
4429class Engine :
4530 def __init__ (self , max_workers : int = 4 ):
4631 self .max_workers = max_workers
4732
4833 def run (self , ops : List [OpNode ], ctx : Context ):
34+ self ._validate (ops )
4935 name2op = {operation .name : operation for operation in ops }
5036
5137 # topological sort
@@ -81,7 +67,7 @@ def _exec(n: str):
8167 return
8268 try :
8369 name2op [n ].func (name2op [n ], ctx )
84- except Exception : # pylint: disable=broad-except
70+ except Exception :
8571 exc [n ] = traceback .format_exc ()
8672 done [n ].set ()
8773
@@ -96,6 +82,20 @@ def _exec(n: str):
9682 + "\n " .join (f"---- { op } ----\n { tb } " for op , tb in exc .items ())
9783 )
9884
85+ @staticmethod
86+ def _validate (ops : List [OpNode ]):
87+ name_set = set ()
88+ for op in ops :
89+ if op .name in name_set :
90+ raise ValueError (f"Duplicate operation name: { op .name } " )
91+ name_set .add (op .name )
92+ for op in ops :
93+ for dep in op .deps :
94+ if dep not in name_set :
95+ raise ValueError (
96+ f"Operation { op .name } has unknown dependency: { dep } "
97+ )
98+
9999
100100def collect_ops (config : dict , graph_gen ) -> List [OpNode ]:
101101 """
@@ -106,16 +106,20 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
106106 ops : List [OpNode ] = []
107107 for stage in config ["pipeline" ]:
108108 name = stage ["name" ]
109- method = getattr (graph_gen , name )
110- op_node = method .op_node
111-
112- # if there are runtime dependencies, override them
113- runtime_deps = stage .get ("deps" , op_node .deps )
114- op_node .deps = runtime_deps
109+ method_name = stage .get ("op_key" )
110+ method = getattr (graph_gen , method_name )
111+ deps = stage .get ("deps" , [])
115112
116113 if "params" in stage :
117- op_node .func = lambda self , ctx , m = method , sc = stage : m (sc .get ("params" , {}))
114+
115+ def func (self , ctx , _method = method , _params = stage .get ("params" , {})):
116+ return _method (_params )
117+
118118 else :
119- op_node .func = lambda self , ctx , m = method : m ()
119+
120+ def func (self , ctx , _method = method ):
121+ return _method ()
122+
123+ op_node = OpNode (name = name , deps = deps , func = func )
120124 ops .append (op_node )
121125 return ops
0 commit comments