11import argparse
22import importlib
33import traceback
4+ from collections .abc import Callable
45from concurrent .futures import Future
6+ from typing import Any
57
68from flask import Flask , jsonify , request
79
2022# Global engine instance - must be TrainEngine or InferenceEngine
2123_engine : TrainEngine | InferenceEngine | None = None
2224
25+
def _handle_submit(
    engine: InferenceEngine,
    args: list[Any],
    kwargs: dict[str, Any],
) -> tuple[list[Any], dict[str, Any]]:
    """Translate serialized ``submit`` RPC arguments into live objects.

    The HTTP caller ships the workflow as a dotted import path plus
    serialized constructor kwargs. This handler imports and instantiates
    the workflow, deserializes the episode data (which may contain
    tensors), optionally imports a filtering function, and returns the
    ``(args, kwargs)`` the engine's real ``submit`` method expects.

    Args:
        engine: Target inference engine. Unused here, but part of the
            common handler signature used by ``_METHOD_HANDLERS``.
        args: Positional RPC arguments. Discarded — ``submit`` receives
            only the rebuilt keyword arguments.
        kwargs: Must contain ``workflow_path``, ``workflow_kwargs``,
            ``data`` and ``should_accept_path``.

    Returns:
        A ``(new_args, new_kwargs)`` pair ready to be forwarded to the
        engine's ``submit`` method.

    Raises:
        KeyError: If a required key is missing from ``kwargs``.
        ImportError: If a dotted module path cannot be imported.
        AttributeError: If the named attribute is absent from its module.
    """

    def _import_attr(dotted_path: str) -> Any:
        # Resolve "pkg.module.Name" to the attribute `Name` on pkg.module.
        module_path, attr_name = dotted_path.rsplit(".", 1)
        return getattr(importlib.import_module(module_path), attr_name)

    workflow_path = kwargs["workflow_path"]
    should_accept_path = kwargs["should_accept_path"]

    # Episode data may contain serialized tensors; rebuild them first.
    episode_data = deserialize_value(kwargs["data"])

    workflow_class = _import_attr(workflow_path)
    logger.info(f"Imported workflow class: {workflow_path}")

    workflow_kwargs = deserialize_value(kwargs["workflow_kwargs"])
    workflow = workflow_class(**workflow_kwargs)
    logger.info(f"Workflow '{workflow_path}' instantiated successfully")

    should_accept = None
    if should_accept_path is not None:
        should_accept = _import_attr(should_accept_path)
        logger.info(f"Imported filtering function: {should_accept_path}")

    new_args: list[Any] = []
    new_kwargs: dict[str, Any] = dict(
        data=episode_data,
        workflow=workflow,
        should_accept=should_accept,
    )
    return new_args, new_kwargs
61+
62+
# Dispatch table: engine method name -> pre-call argument transformer.
# Each handler receives (engine, args, kwargs) from the RPC layer and
# returns the (args, kwargs) pair to forward to the actual engine method.
_METHOD_HANDLERS: dict[
    str,
    Callable[
        [InferenceEngine, list[Any], dict[str, Any]],
        tuple[list[Any], dict[str, Any]],
    ],
] = {"submit": _handle_submit}
71+
2372# Create Flask app
2473app = Flask (__name__ )
2574
@@ -207,42 +256,12 @@ def call_engine_method():
207256 500 ,
208257 )
209258
210- # Special case for `submit` on inference engines
259+ # Special handling for some methods on inference engines
211260 try :
212- if method_name == "submit" and isinstance (_engine , InferenceEngine ):
213- workflow_path = kwargs ["workflow_path" ]
214- workflow_kwargs = kwargs ["workflow_kwargs" ]
215- episode_data = kwargs ["data" ]
216- should_accept_path = kwargs ["should_accept_path" ]
217-
218- # Deserialize episode_data (may contain tensors)
219- episode_data = deserialize_value (episode_data )
220-
221- # Dynamic import workflow
222- module_path , class_name = workflow_path .rsplit ("." , 1 )
223- module = importlib .import_module (module_path )
224- workflow_class = getattr (module , class_name )
225- logger .info (f"Imported workflow class: { workflow_path } " )
226-
227- # Instantiate workflow
228- workflow_kwargs = deserialize_value (workflow_kwargs )
229- workflow = workflow_class (** workflow_kwargs )
230- logger .info (f"Workflow '{ workflow_path } ' instantiated successfully" )
231-
232- should_accept = None
233- if should_accept_path is not None :
234- # Dynamic import filtering function
235- module_path , fn_name = should_accept_path .rsplit ("." , 1 )
236- module = importlib .import_module (module_path )
237- should_accept = getattr (module , fn_name )
238- logger .info (f"Imported filtering function: { should_accept_path } " )
239-
240- args = []
241- kwargs = dict (
242- data = episode_data ,
243- workflow = workflow ,
244- should_accept = should_accept ,
245- )
261+ if isinstance (_engine , InferenceEngine ):
262+ handler = _METHOD_HANDLERS .get (method_name )
263+ if handler is not None :
264+ args , kwargs = handler (_engine , args , kwargs )
246265 except Exception as e :
247266 logger .error (
248267 f"Workflow data conversion failed: { e } \n { traceback .format_exc ()} "
@@ -261,7 +280,9 @@ def call_engine_method():
261280
262281 # HACK: handle update weights future
263282 if isinstance (result , Future ):
283+ logger .info ("Waiting for update weights future" )
264284 result = result .result ()
285+ logger .info ("Update weights future done" )
265286
266287 # Serialize result (convert tensors to SerializedTensor dicts)
267288 serialized_result = serialize_value (result )
@@ -296,7 +317,11 @@ def export_stats():
296317 return jsonify ({"error" : "Engine not initialized" }), 503
297318
298319 # TrainEngine: reduce stats across data_parallel_group
299- assert isinstance (_engine , TrainEngine )
320+ if not isinstance (_engine , TrainEngine ):
321+ return (
322+ jsonify ({"error" : "/export_stats is only available for TrainEngine" }),
323+ 400 ,
324+ )
300325 result = stats_tracker .export (reduce_group = _engine .data_parallel_group )
301326 return jsonify ({"status" : "success" , "result" : result })
302327
0 commit comments