2424from codeflash .code_utils .code_utils import get_run_tmp_file
2525from codeflash .code_utils .compat import SAFE_SYS_EXECUTABLE
2626from codeflash .code_utils .config_parser import parse_config_file
27+ from codeflash .tracing .pytest_parallelization import pytest_split
2728
2829if TYPE_CHECKING :
2930 from argparse import Namespace
@@ -86,51 +87,97 @@ def main(args: Namespace | None = None) -> ArgumentParser:
8687 config , found_config_path = parse_config_file (parsed_args .codeflash_config )
8788 project_root = project_root_from_module_root (Path (config ["module_root" ]), found_config_path )
8889 if len (unknown_args ) > 0 :
90+ args_dict = {
91+ "functions" : parsed_args .only_functions ,
92+ "disable" : False ,
93+ "project_root" : str (project_root ),
94+ "max_function_count" : parsed_args .max_function_count ,
95+ "timeout" : parsed_args .tracer_timeout ,
96+ "progname" : unknown_args [0 ],
97+ "config" : config ,
98+ "module" : parsed_args .module ,
99+ }
89100 try :
90- result_pickle_file_path = get_run_tmp_file ("tracer_results_file.pkl" )
91- args_dict = {
92- "result_pickle_file_path" : str (result_pickle_file_path ),
93- "output" : str (parsed_args .outfile ),
94- "functions" : parsed_args .only_functions ,
95- "disable" : False ,
96- "project_root" : str (project_root ),
97- "max_function_count" : parsed_args .max_function_count ,
98- "timeout" : parsed_args .tracer_timeout ,
99- "command" : " " .join (sys .argv ),
100- "progname" : unknown_args [0 ],
101- "config" : config ,
102- "module" : parsed_args .module ,
103- }
104-
105- subprocess .run (
106- [
107- SAFE_SYS_EXECUTABLE ,
108- Path (__file__ ).parent / "tracing" / "tracing_new_process.py" ,
109- * sys .argv ,
110- json .dumps (args_dict ),
111- ],
112- cwd = Path .cwd (),
113- check = False ,
114- )
115- try :
116- with result_pickle_file_path .open (mode = "rb" ) as f :
117- data = pickle .load (f )
118- except Exception :
119- console .print ("❌ Failed to trace. Exiting..." )
120- sys .exit (1 )
121- finally :
122- result_pickle_file_path .unlink (missing_ok = True )
123-
124- replay_test_path = data ["replay_test_file_path" ]
125- if not parsed_args .trace_only and replay_test_path is not None :
101+ pytest_splits = []
102+ test_paths = []
103+ replay_test_paths = []
104+ if parsed_args .module and unknown_args [0 ] == "pytest" :
105+ pytest_splits , test_paths = pytest_split (unknown_args [1 :])
106+
107+ if len (pytest_splits ) > 1 :
108+ processes = []
109+ test_paths_set = set (test_paths )
110+ result_pickle_file_paths = []
111+ for i , test_split in enumerate (pytest_splits , start = 1 ):
112+ result_pickle_file_path = get_run_tmp_file (Path (f"tracer_results_file_{ i } .pkl" ))
113+ result_pickle_file_paths .append (result_pickle_file_path )
114+ args_dict ["result_pickle_file_path" ] = str (result_pickle_file_path )
115+ outpath = parsed_args .outfile
116+ outpath = outpath .parent / f"{ outpath .stem } _{ i } { outpath .suffix } "
117+ args_dict ["output" ] = str (outpath )
118+ updated_sys_argv = []
119+ for elem in sys .argv :
120+ if elem in test_paths_set :
121+ updated_sys_argv .extend (test_split )
122+ else :
123+ updated_sys_argv .append (elem )
124+ args_dict ["command" ] = " " .join (updated_sys_argv )
125+ processes .append (
126+ subprocess .Popen (
127+ [
128+ SAFE_SYS_EXECUTABLE ,
129+ Path (__file__ ).parent / "tracing" / "tracing_new_process.py" ,
130+ * updated_sys_argv ,
131+ json .dumps (args_dict ),
132+ ],
133+ cwd = Path .cwd (),
134+ )
135+ )
136+ for process in processes :
137+ process .wait ()
138+ for result_pickle_file_path in result_pickle_file_paths :
139+ try :
140+ with result_pickle_file_path .open (mode = "rb" ) as f :
141+ data = pickle .load (f )
142+ replay_test_paths .append (str (data ["replay_test_file_path" ]))
143+ except Exception :
144+ console .print ("❌ Failed to trace. Exiting..." )
145+ sys .exit (1 )
146+ finally :
147+ result_pickle_file_path .unlink (missing_ok = True )
148+ else :
149+ result_pickle_file_path = get_run_tmp_file (Path ("tracer_results_file.pkl" ))
150+ args_dict ["result_pickle_file_path" ] = str (result_pickle_file_path )
151+ args_dict ["output" ] = str (parsed_args .outfile )
152+ args_dict ["command" ] = " " .join (sys .argv )
153+
154+ subprocess .run (
155+ [
156+ SAFE_SYS_EXECUTABLE ,
157+ Path (__file__ ).parent / "tracing" / "tracing_new_process.py" ,
158+ * sys .argv ,
159+ json .dumps (args_dict ),
160+ ],
161+ cwd = Path .cwd (),
162+ check = False ,
163+ )
164+ try :
165+ with result_pickle_file_path .open (mode = "rb" ) as f :
166+ data = pickle .load (f )
167+ replay_test_paths .append (str (data ["replay_test_file_path" ]))
168+ except Exception :
169+ console .print ("❌ Failed to trace. Exiting..." )
170+ sys .exit (1 )
171+ finally :
172+ result_pickle_file_path .unlink (missing_ok = True )
173+ if not parsed_args .trace_only and replay_test_paths :
126174 from codeflash .cli_cmds .cli import parse_args , process_pyproject_config
127175 from codeflash .cli_cmds .cmd_init import CODEFLASH_LOGO
128176 from codeflash .cli_cmds .console import paneled_text
129177 from codeflash .telemetry import posthog_cf
130178 from codeflash .telemetry .sentry import init_sentry
131179
132- sys .argv = ["codeflash" , "--replay-test" , str (replay_test_path )]
133-
180+ sys .argv = ["codeflash" , "--replay-test" , * replay_test_paths ]
134181 args = parse_args ()
135182 paneled_text (
136183 CODEFLASH_LOGO ,
@@ -150,8 +197,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
150197 # Delete the trace file and the replay test file if they exist
151198 if outfile :
152199 outfile .unlink (missing_ok = True )
153- if replay_test_path :
154- replay_test_path .unlink (missing_ok = True )
200+ for replay_test_path in replay_test_paths :
201+ Path ( replay_test_path ) .unlink (missing_ok = True )
155202
156203 except BrokenPipeError as exc :
157204 # Prevent "Exception ignored" during interpreter shutdown.
0 commit comments