2424from codeflash .code_utils .code_utils import get_run_tmp_file
2525from codeflash .code_utils .compat import SAFE_SYS_EXECUTABLE
2626from codeflash .code_utils .config_parser import parse_config_file
27+ from codeflash .tracing .pytest_parallelization import pytest_split
2728
2829if TYPE_CHECKING :
2930 from argparse import Namespace
@@ -86,51 +87,102 @@ def main(args: Namespace | None = None) -> ArgumentParser:
8687 config , found_config_path = parse_config_file (parsed_args .codeflash_config )
8788 project_root = project_root_from_module_root (Path (config ["module_root" ]), found_config_path )
8889 if len (unknown_args ) > 0 :
90+ args_dict = {
91+ "functions" : parsed_args .only_functions ,
92+ "disable" : False ,
93+ "project_root" : str (project_root ),
94+ "max_function_count" : parsed_args .max_function_count ,
95+ "timeout" : parsed_args .tracer_timeout ,
96+ "progname" : unknown_args [0 ],
97+ "config" : config ,
98+ "module" : parsed_args .module ,
99+ }
89100 try :
90- result_pickle_file_path = get_run_tmp_file ("tracer_results_file.pkl" )
91- args_dict = {
92- "result_pickle_file_path" : str (result_pickle_file_path ),
93- "output" : str (parsed_args .outfile ),
94- "functions" : parsed_args .only_functions ,
95- "disable" : False ,
96- "project_root" : str (project_root ),
97- "max_function_count" : parsed_args .max_function_count ,
98- "timeout" : parsed_args .tracer_timeout ,
99- "command" : " " .join (sys .argv ),
100- "progname" : unknown_args [0 ],
101- "config" : config ,
102- "module" : parsed_args .module ,
103- }
104-
105- subprocess .run (
106- [
107- SAFE_SYS_EXECUTABLE ,
108- Path (__file__ ).parent / "tracing" / "tracing_new_process.py" ,
109- * sys .argv ,
110- json .dumps (args_dict ),
111- ],
112- cwd = Path .cwd (),
113- check = False ,
114- )
115- try :
116- with result_pickle_file_path .open (mode = "rb" ) as f :
117- data = pickle .load (f )
118- except Exception :
119- console .print ("❌ Failed to trace. Exiting..." )
120- sys .exit (1 )
121- finally :
122- result_pickle_file_path .unlink (missing_ok = True )
123-
124- replay_test_path = data ["replay_test_file_path" ]
125- if not parsed_args .trace_only and replay_test_path is not None :
101+ pytest_splits = []
102+ test_paths = []
103+ replay_test_paths = []
104+ if parsed_args .module and unknown_args [0 ] == "pytest" :
105+ pytest_splits , test_paths = pytest_split (unknown_args [1 :])
106+ print (pytest_splits )
107+
108+ if len (pytest_splits ) > 1 :
109+ processes = []
110+ test_paths_set = set (test_paths )
111+ result_pickle_file_paths = []
112+ for i , test_split in enumerate (pytest_splits , start = 1 ):
113+ result_pickle_file_path = get_run_tmp_file (f"tracer_results_file_{ i } .pkl" )
114+ result_pickle_file_paths .append (result_pickle_file_path )
115+ args_dict ["result_pickle_file_path" ] = str (result_pickle_file_path )
116+ outpath = parsed_args .outfile
117+ outpath = outpath .parent / f"{ outpath .stem } _{ i } { outpath .suffix } "
118+ args_dict ["output" ] = str (outpath )
119+ added_paths = False
120+ updated_sys_argv = []
121+ for elem in sys .argv :
122+ if elem in test_paths_set :
123+ if not added_paths :
124+ updated_sys_argv .extend (test_split )
125+ else :
126+ updated_sys_argv .append (elem )
127+ args_dict ["command" ] = " " .join (updated_sys_argv )
128+ processes .append (
129+ subprocess .Popen (
130+ [
131+ SAFE_SYS_EXECUTABLE ,
132+ Path (__file__ ).parent / "tracing" / "tracing_new_process.py" ,
133+ * updated_sys_argv ,
134+ json .dumps (args_dict ),
135+ ],
136+ cwd = Path .cwd (),
137+ )
138+ )
139+ for process in processes :
140+ process .wait ()
141+ for result_pickle_file_path in result_pickle_file_paths :
142+ try :
143+ with result_pickle_file_path .open (mode = "rb" ) as f :
144+ data = pickle .load (f )
145+ replay_test_paths .append (str (data ["replay_test_file_path" ]))
146+ except Exception :
147+ console .print ("❌ Failed to trace. Exiting..." )
148+ sys .exit (1 )
149+ finally :
150+ result_pickle_file_path .unlink (missing_ok = True )
151+ else :
152+ result_pickle_file_path = get_run_tmp_file ("tracer_results_file.pkl" )
153+ args_dict ["result_pickle_file_path" ] = str (result_pickle_file_path )
154+ args_dict ["output" ] = str (parsed_args .outfile )
155+ args_dict ["command" ] = " " .join (sys .argv )
156+
157+ subprocess .run (
158+ [
159+ SAFE_SYS_EXECUTABLE ,
160+ Path (__file__ ).parent / "tracing" / "tracing_new_process.py" ,
161+ * sys .argv ,
162+ json .dumps (args_dict ),
163+ ],
164+ cwd = Path .cwd (),
165+ check = False ,
166+ )
167+ try :
168+ with result_pickle_file_path .open (mode = "rb" ) as f :
169+ data = pickle .load (f )
170+ replay_test_paths .append (str (data ["replay_test_file_path" ]))
171+ except Exception :
172+ console .print ("❌ Failed to trace. Exiting..." )
173+ sys .exit (1 )
174+ finally :
175+ result_pickle_file_path .unlink (missing_ok = True )
176+
177+ if not parsed_args .trace_only and replay_test_paths :
126178 from codeflash .cli_cmds .cli import parse_args , process_pyproject_config
127179 from codeflash .cli_cmds .cmd_init import CODEFLASH_LOGO
128180 from codeflash .cli_cmds .console import paneled_text
129181 from codeflash .telemetry import posthog_cf
130182 from codeflash .telemetry .sentry import init_sentry
131183
132- sys .argv = ["codeflash" , "--replay-test" , str ( replay_test_path ) ]
133-
184+ sys .argv = ["codeflash" , "--replay-test" , * replay_test_paths ]
185+ print ( sys . argv )
134186 args = parse_args ()
135187 paneled_text (
136188 CODEFLASH_LOGO ,
@@ -150,7 +202,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
150202 # Delete the trace file and the replay test file if they exist
151203 if outfile :
152204 outfile .unlink (missing_ok = True )
153- if replay_test_path :
205+ for replay_test_path in replay_test_paths :
154206 replay_test_path .unlink (missing_ok = True )
155207
156208 except BrokenPipeError as exc :
0 commit comments