|
9 | 9 | import subprocess |
10 | 10 | import unittest |
11 | 11 | from collections import defaultdict |
| 12 | +from concurrent.futures import ProcessPoolExecutor, as_completed |
12 | 13 | from pathlib import Path |
13 | 14 | from typing import TYPE_CHECKING, Callable, Optional |
14 | 15 |
|
|
18 | 19 | from rich.text import Text |
19 | 20 |
|
20 | 21 | from codeflash.cli_cmds.console import console, logger, test_files_progress_bar |
21 | | -from codeflash.code_utils.code_utils import ( |
22 | | - ImportErrorPattern, |
23 | | - custom_addopts, |
24 | | - get_run_tmp_file, |
25 | | - module_name_from_file_path, |
26 | | -) |
| 22 | +from codeflash.code_utils.code_utils import ImportErrorPattern, custom_addopts, get_run_tmp_file |
27 | 23 | from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db |
28 | 24 | from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType |
29 | 25 |
|
@@ -288,157 +284,176 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N |
288 | 284 | return False, function_name, None |
289 | 285 |
|
290 | 286 |
|
291 | | -def process_test_files( |
292 | | - file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig |
293 | | -) -> tuple[dict[str, list[FunctionCalledInTest]], int]: |
| 287 | +def _process_single_test_file( |
| 288 | + test_file: Path, functions: list[TestsInFile], project_root_path: Path, test_framework: str |
| 289 | +) -> tuple[Path, set]: |
294 | 290 | import jedi |
295 | 291 |
|
296 | | - project_root_path = cfg.project_root_path |
297 | | - test_framework = cfg.test_framework |
| 292 | + local_function_to_test_map = set() |
298 | 293 |
|
299 | | - function_to_test_map = defaultdict(set) |
300 | | - jedi_project = jedi.Project(path=project_root_path) |
| 294 | + try: |
| 295 | + jedi_project = jedi.Project(path=project_root_path) |
| 296 | + script = jedi.Script(path=test_file, project=jedi_project) |
| 297 | + test_functions = set() |
301 | 298 |
|
302 | | - with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
303 | | - progress, |
304 | | - task_id, |
305 | | - ): |
306 | | - for test_file, functions in file_to_test_map.items(): |
307 | | - try: |
308 | | - script = jedi.Script(path=test_file, project=jedi_project) |
309 | | - test_functions = set() |
310 | | - |
311 | | - all_names = script.get_names(all_scopes=True, references=True) |
312 | | - all_defs = script.get_names(all_scopes=True, definitions=True) |
313 | | - all_names_top = script.get_names(all_scopes=True) |
314 | | - |
315 | | - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
316 | | - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
317 | | - except Exception as e: |
318 | | - logger.debug(f"Failed to get jedi script for {test_file}: {e}") |
319 | | - progress.advance(task_id) |
320 | | - continue |
| 299 | + all_names = script.get_names(all_scopes=True, references=True) |
| 300 | + all_defs = script.get_names(all_scopes=True, definitions=True) |
| 301 | + all_names_top = script.get_names(all_scopes=True) |
| 302 | + |
| 303 | + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
| 304 | + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
| 305 | + except Exception as e: |
| 306 | + logger.debug(f"Failed to get jedi script for {test_file}: {e}") |
| 307 | + return test_file, local_function_to_test_map |
| 308 | + |
| 309 | + if test_framework == "pytest": |
| 310 | + for function in functions: |
| 311 | + if "[" in function.test_function: |
| 312 | + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] |
| 313 | + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] |
| 314 | + if function_name in top_level_functions: |
| 315 | + test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type)) |
| 316 | + elif function.test_function in top_level_functions: |
| 317 | + test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type)) |
| 318 | + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): |
| 319 | + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) |
| 320 | + if base_name in top_level_functions: |
| 321 | + test_functions.add( |
| 322 | + TestFunction( |
| 323 | + function_name=base_name, |
| 324 | + test_class=function.test_class, |
| 325 | + parameters=function.test_function, |
| 326 | + test_type=function.test_type, |
| 327 | + ) |
| 328 | + ) |
321 | 329 |
|
322 | | - if test_framework == "pytest": |
323 | | - for function in functions: |
324 | | - if "[" in function.test_function: |
325 | | - function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] |
326 | | - parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] |
327 | | - if function_name in top_level_functions: |
| 330 | + elif test_framework == "unittest": |
| 331 | + functions_to_search = [elem.test_function for elem in functions] |
| 332 | + test_suites = {elem.test_class for elem in functions} |
| 333 | + |
| 334 | + matching_names = test_suites & top_level_classes.keys() |
| 335 | + for matched_name in matching_names: |
| 336 | + for def_name in all_defs: |
| 337 | + if ( |
| 338 | + def_name.type == "function" |
| 339 | + and def_name.full_name is not None |
| 340 | + and f".{matched_name}." in def_name.full_name |
| 341 | + ): |
| 342 | + for function in functions_to_search: |
| 343 | + (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) |
| 344 | + |
| 345 | + if is_parameterized and new_function == def_name.name: |
328 | 346 | test_functions.add( |
329 | | - TestFunction(function_name, function.test_class, parameters, function.test_type) |
| 347 | + TestFunction( |
| 348 | + function_name=def_name.name, |
| 349 | + test_class=matched_name, |
| 350 | + parameters=parameters, |
| 351 | + test_type=functions[0].test_type, |
| 352 | + ) |
330 | 353 | ) |
331 | | - elif function.test_function in top_level_functions: |
332 | | - test_functions.add( |
333 | | - TestFunction(function.test_function, function.test_class, None, function.test_type) |
334 | | - ) |
335 | | - elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): |
336 | | - base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) |
337 | | - if base_name in top_level_functions: |
| 354 | + elif function == def_name.name: |
338 | 355 | test_functions.add( |
339 | 356 | TestFunction( |
340 | | - function_name=base_name, |
341 | | - test_class=function.test_class, |
342 | | - parameters=function.test_function, |
343 | | - test_type=function.test_type, |
| 357 | + function_name=def_name.name, |
| 358 | + test_class=matched_name, |
| 359 | + parameters=None, |
| 360 | + test_type=functions[0].test_type, |
344 | 361 | ) |
345 | 362 | ) |
346 | 363 |
|
347 | | - elif test_framework == "unittest": |
348 | | - functions_to_search = [elem.test_function for elem in functions] |
349 | | - test_suites = {elem.test_class for elem in functions} |
350 | | - |
351 | | - matching_names = test_suites & top_level_classes.keys() |
352 | | - for matched_name in matching_names: |
353 | | - for def_name in all_defs: |
354 | | - if ( |
355 | | - def_name.type == "function" |
356 | | - and def_name.full_name is not None |
357 | | - and f".{matched_name}." in def_name.full_name |
358 | | - ): |
359 | | - for function in functions_to_search: |
360 | | - (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) |
361 | | - |
362 | | - if is_parameterized and new_function == def_name.name: |
363 | | - test_functions.add( |
364 | | - TestFunction( |
365 | | - function_name=def_name.name, |
366 | | - test_class=matched_name, |
367 | | - parameters=parameters, |
368 | | - test_type=functions[0].test_type, |
369 | | - ) |
370 | | - ) |
371 | | - elif function == def_name.name: |
372 | | - test_functions.add( |
373 | | - TestFunction( |
374 | | - function_name=def_name.name, |
375 | | - test_class=matched_name, |
376 | | - parameters=None, |
377 | | - test_type=functions[0].test_type, |
378 | | - ) |
379 | | - ) |
380 | | - |
381 | | - test_functions_list = list(test_functions) |
382 | | - test_functions_raw = [elem.function_name for elem in test_functions_list] |
383 | | - |
384 | | - test_functions_by_name = defaultdict(list) |
385 | | - for i, func_name in enumerate(test_functions_raw): |
386 | | - test_functions_by_name[func_name].append(i) |
387 | | - |
388 | | - for name in all_names: |
389 | | - if name.full_name is None: |
390 | | - continue |
391 | | - m = FUNCTION_NAME_REGEX.search(name.full_name) |
392 | | - if not m: |
393 | | - continue |
394 | | - |
395 | | - scope = m.group(1) |
396 | | - if scope not in test_functions_by_name: |
397 | | - continue |
398 | | - |
399 | | - try: |
400 | | - definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
401 | | - except Exception as e: |
402 | | - logger.debug(str(e)) |
403 | | - continue |
404 | | - |
405 | | - if not definition or definition[0].type != "function": |
406 | | - continue |
407 | | - |
408 | | - definition_path = str(definition[0].module_path) |
409 | | - if ( |
410 | | - definition_path.startswith(str(project_root_path) + os.sep) |
411 | | - and definition[0].module_name != name.module_name |
412 | | - and definition[0].full_name is not None |
413 | | - ): |
414 | | - for index in test_functions_by_name[scope]: |
415 | | - scope_test_function = test_functions_list[index].function_name |
416 | | - scope_test_class = test_functions_list[index].test_class |
417 | | - scope_parameters = test_functions_list[index].parameters |
418 | | - test_type = test_functions_list[index].test_type |
419 | | - |
420 | | - if scope_parameters is not None: |
421 | | - if test_framework == "pytest": |
422 | | - scope_test_function += "[" + scope_parameters + "]" |
423 | | - if test_framework == "unittest": |
424 | | - scope_test_function += "_" + scope_parameters |
425 | | - |
426 | | - full_name_without_module_prefix = definition[0].full_name.replace( |
427 | | - definition[0].module_name + ".", "", 1 |
428 | | - ) |
429 | | - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" |
430 | | - |
431 | | - function_to_test_map[qualified_name_with_modules_from_root].add( |
432 | | - FunctionCalledInTest( |
433 | | - tests_in_file=TestsInFile( |
434 | | - test_file=test_file, |
435 | | - test_class=scope_test_class, |
436 | | - test_function=scope_test_function, |
437 | | - test_type=test_type, |
438 | | - ), |
439 | | - position=CodePosition(line_no=name.line, col_no=name.column), |
440 | | - ) |
441 | | - ) |
| 364 | + test_functions_list = list(test_functions) |
| 365 | + test_functions_raw = [elem.function_name for elem in test_functions_list] |
| 366 | + |
| 367 | + test_functions_by_name = defaultdict(list) |
| 368 | + for i, func_name in enumerate(test_functions_raw): |
| 369 | + test_functions_by_name[func_name].append(i) |
| 370 | + |
| 371 | + for name in all_names: |
| 372 | + if name.full_name is None: |
| 373 | + continue |
| 374 | + m = FUNCTION_NAME_REGEX.search(name.full_name) |
| 375 | + if not m: |
| 376 | + continue |
| 377 | + |
| 378 | + scope = m.group(1) |
| 379 | + if scope not in test_functions_by_name: |
| 380 | + continue |
| 381 | + |
| 382 | + try: |
| 383 | + definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
| 384 | + except Exception as e: |
| 385 | + logger.debug(str(e)) |
| 386 | + continue |
| 387 | + |
| 388 | + if not definition or definition[0].type != "function": |
| 389 | + continue |
| 390 | + |
| 391 | + definition_path = str(definition[0].module_path) |
| 392 | + if ( |
| 393 | + definition_path.startswith(str(project_root_path) + os.sep) |
| 394 | + and definition[0].module_name != name.module_name |
| 395 | + and definition[0].full_name is not None |
| 396 | + ): |
| 397 | + for index in test_functions_by_name[scope]: |
| 398 | + scope_test_function = test_functions_list[index].function_name |
| 399 | + scope_test_class = test_functions_list[index].test_class |
| 400 | + scope_parameters = test_functions_list[index].parameters |
| 401 | + test_type = test_functions_list[index].test_type |
| 402 | + |
| 403 | + if scope_parameters is not None: |
| 404 | + if test_framework == "pytest": |
| 405 | + scope_test_function += "[" + scope_parameters + "]" |
| 406 | + if test_framework == "unittest": |
| 407 | + scope_test_function += "_" + scope_parameters |
| 408 | + |
| 409 | + full_name_without_module_prefix = definition[0].full_name.replace( |
| 410 | + definition[0].module_name + ".", "", 1 |
| 411 | + ) |
| 412 | + from codeflash.code_utils.code_utils import module_name_from_file_path |
| 413 | + |
| 414 | + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" |
| 415 | + |
| 416 | + function_called_in_test = FunctionCalledInTest( |
| 417 | + tests_in_file=TestsInFile( |
| 418 | + test_file=test_file, |
| 419 | + test_class=scope_test_class, |
| 420 | + test_function=scope_test_function, |
| 421 | + test_type=test_type, |
| 422 | + ), |
| 423 | + position=CodePosition(line_no=name.line, col_no=name.column), |
| 424 | + ) |
| 425 | + local_function_to_test_map.add((qualified_name_with_modules_from_root, function_called_in_test)) |
| 426 | + |
| 427 | + return test_file, local_function_to_test_map |
| 428 | + |
| 429 | + |
| 430 | +def process_test_files( |
| 431 | + file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig |
| 432 | +) -> tuple[dict[str, list[FunctionCalledInTest]], int]: |
| 433 | + project_root_path = cfg.project_root_path |
| 434 | + test_framework = cfg.test_framework |
| 435 | + |
| 436 | + function_to_test_map = defaultdict(set) |
| 437 | + |
| 438 | + with ( |
| 439 | + test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
| 440 | + progress, |
| 441 | + task_id, |
| 442 | + ), |
| 443 | + ProcessPoolExecutor() as executor, |
| 444 | + ): |
| 445 | + future_to_file = { |
| 446 | + executor.submit( |
| 447 | + _process_single_test_file, test_file, functions, project_root_path, test_framework |
| 448 | + ): test_file |
| 449 | + for test_file, functions in file_to_test_map.items() |
| 450 | + } |
| 451 | + |
| 452 | + for future in as_completed(future_to_file): |
| 453 | + _, local_results = future.result() |
| 454 | + |
| 455 | + for qualified_name, function_called_in_test in local_results: |
| 456 | + function_to_test_map[qualified_name].add(function_called_in_test) |
442 | 457 |
|
443 | 458 | progress.advance(task_id) |
444 | 459 |
|
|
0 commit comments