@@ -81,10 +81,7 @@ def run_codeflash_command(
8181) -> bool :
8282 logging .basicConfig (level = logging .INFO )
8383 if config .trace_mode :
84- if config .trace_load == "workload" :
85- return run_trace_test (cwd , config , expected_improvement_pct )
86- if config .trace_load == "testbench" :
87- return run_trace_test2 (cwd , config , expected_improvement_pct )
84+ return run_trace_test (cwd , config , expected_improvement_pct )
8885
8986 path_to_file = cwd / config .file_path
9087 file_contents = path_to_file .read_text ("utf-8" )
@@ -188,55 +185,11 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
188185 # First command: Run the tracer
189186 test_root = cwd / "tests" / (config .test_framework or "" )
190187 clear_directory (test_root )
191- command = ["python" , "-m" , "codeflash.tracer" , "-o" , "codeflash.trace" , "workload.py" ]
192- process = subprocess .Popen (
193- command , stdout = subprocess .PIPE , stderr = subprocess .STDOUT , text = True , cwd = str (cwd ), env = os .environ .copy ()
194- )
195-
196- output = []
197- for line in process .stdout :
198- logging .info (line .strip ())
199- output .append (line )
200-
201- return_code = process .wait ()
202- stdout = "" .join (output )
203-
204- if return_code != 0 :
205- logging .error (f"Tracer command returned exit code { return_code } " )
206- return False
207-
208- functions_traced = re .search (r"Traced (\d+) function calls successfully and replay test created at - (.*)$" , stdout )
209- if not functions_traced or int (functions_traced .group (1 )) != 3 :
210- logging .error ("Expected 3 traced functions" )
211- return False
212-
213- replay_test_path = pathlib .Path (functions_traced .group (2 ))
214- if not replay_test_path .exists ():
215- logging .error (f"Replay test file missing at { replay_test_path } " )
216- return False
217188
218- # Second command: Run optimization
219- command = ["python" , "../../../codeflash/main.py" , "--replay-test" , str (replay_test_path ), "--no-pr" ]
220- process = subprocess .Popen (
221- command , stdout = subprocess .PIPE , stderr = subprocess .STDOUT , text = True , cwd = str (cwd ), env = os .environ .copy ()
222- )
189+ trace_script = "workload.py" if config .trace_load == "workload" else "testbench.py"
190+ expected_traced_functions = 3 if config .trace_load == "workload" else 5
223191
224- output = []
225- for line in process .stdout :
226- logging .info (line .strip ())
227- output .append (line )
228-
229- return_code = process .wait ()
230- stdout = "" .join (output )
231-
232- return validate_output (stdout , return_code , expected_improvement_pct , config )
233-
234-
235- def run_trace_test2 (cwd : pathlib .Path , config : TestConfig , expected_improvement_pct : int ) -> bool :
236- # First command: Run the tracer
237- test_root = cwd / "tests" / (config .test_framework or "" )
238- clear_directory (test_root )
239- command = ["python" , "-m" , "codeflash.tracer" , "-o" , "codeflash.trace" , "testbench.py" ]
192+ command = ["python" , "-m" , "codeflash.tracer" , "-o" , "codeflash.trace" , trace_script ]
240193 process = subprocess .Popen (
241194 command , stdout = subprocess .PIPE , stderr = subprocess .STDOUT , text = True , cwd = str (cwd ), env = os .environ .copy ()
242195 )
@@ -254,8 +207,8 @@ def run_trace_test2(cwd: pathlib.Path, config: TestConfig, expected_improvement_
254207 return False
255208
256209 functions_traced = re .search (r"Traced (\d+) function calls successfully and replay test created at - (.*)$" , stdout )
257- if not functions_traced or int (functions_traced .group (1 )) != 5 :
258- logging .error ("Expected 5 traced functions" )
210+ if not functions_traced or int (functions_traced .group (1 )) != expected_traced_functions :
211+ logging .error (f "Expected { expected_traced_functions } traced functions" )
259212 return False
260213
261214 replay_test_path = pathlib .Path (functions_traced .group (2 ))
0 commit comments