|
9 | 9 | import subprocess |
10 | 10 | import unittest |
11 | 11 | from collections import defaultdict |
| 12 | +from concurrent.futures import ProcessPoolExecutor, as_completed |
12 | 13 | from pathlib import Path |
13 | 14 | from typing import TYPE_CHECKING, Callable, Optional |
14 | 15 |
|
@@ -276,192 +277,193 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N |
276 | 277 | return False, function_name, None |
277 | 278 |
|
278 | 279 |
|
279 | | -def process_test_files( |
280 | | - file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig |
281 | | -) -> dict[str, list[FunctionCalledInTest]]: |
| 280 | +def _process_single_test_file( |
| 281 | + test_file: Path, functions: list[TestsInFile], project_root_path: Path, test_framework: str |
| 282 | +) -> tuple[str, list[tuple[str, FunctionCalledInTest]]]: |
282 | 283 | import jedi |
283 | 284 |
|
284 | | - project_root_path = cfg.project_root_path |
285 | | - test_framework = cfg.test_framework |
286 | | - |
287 | | - function_to_test_map = defaultdict(set) |
288 | 285 | jedi_project = jedi.Project(path=project_root_path) |
289 | 286 | goto_cache = {} |
290 | | - tests_cache = TestsCache() |
| 287 | + results = [] |
291 | 288 |
|
292 | | - with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
293 | | - progress, |
294 | | - task_id, |
295 | | - ): |
296 | | - for test_file, functions in file_to_test_map.items(): |
297 | | - file_hash = TestsCache.compute_file_hash(test_file) |
298 | | - cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) |
299 | | - if cached_tests: |
300 | | - self_cur = tests_cache.cur |
301 | | - self_cur.execute( |
302 | | - "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?", |
303 | | - (str(test_file), file_hash), |
304 | | - ) |
305 | | - qualified_names = [row[0] for row in self_cur.fetchall()] |
306 | | - for cached, qualified_name in zip(cached_tests, qualified_names): |
307 | | - function_to_test_map[qualified_name].add(cached) |
308 | | - progress.advance(task_id) |
309 | | - continue |
| 289 | + try: |
| 290 | + script = jedi.Script(path=test_file, project=jedi_project) |
| 291 | + test_functions = set() |
310 | 292 |
|
311 | | - try: |
312 | | - script = jedi.Script(path=test_file, project=jedi_project) |
313 | | - test_functions = set() |
| 293 | + all_names = script.get_names(all_scopes=True, references=True) |
| 294 | + all_defs = script.get_names(all_scopes=True, definitions=True) |
| 295 | + all_names_top = script.get_names(all_scopes=True) |
314 | 296 |
|
315 | | - all_names = script.get_names(all_scopes=True, references=True) |
316 | | - all_defs = script.get_names(all_scopes=True, definitions=True) |
317 | | - all_names_top = script.get_names(all_scopes=True) |
| 297 | + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
| 298 | + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
| 299 | + except Exception as e: |
| 300 | + logger.debug(f"Failed to get jedi script for {test_file}: {e}") |
| 301 | + # tests_cache.close() |
| 302 | + return str(test_file), results |
| 303 | + |
| 304 | + if test_framework == "pytest": |
| 305 | + for function in functions: |
| 306 | + if "[" in function.test_function: |
| 307 | + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] |
| 308 | + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] |
| 309 | + if function_name in top_level_functions: |
| 310 | + test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type)) |
| 311 | + elif function.test_function in top_level_functions: |
| 312 | + test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type)) |
| 313 | + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): |
| 314 | + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) |
| 315 | + if base_name in top_level_functions: |
| 316 | + test_functions.add( |
| 317 | + TestFunction( |
| 318 | + function_name=base_name, |
| 319 | + test_class=function.test_class, |
| 320 | + parameters=function.test_function, |
| 321 | + test_type=function.test_type, |
| 322 | + ) |
| 323 | + ) |
318 | 324 |
|
319 | | - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
320 | | - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
321 | | - except Exception as e: |
322 | | - logger.debug(f"Failed to get jedi script for {test_file}: {e}") |
323 | | - progress.advance(task_id) |
324 | | - continue |
| 325 | + elif test_framework == "unittest": |
| 326 | + functions_to_search = [elem.test_function for elem in functions] |
| 327 | + test_suites = {elem.test_class for elem in functions} |
| 328 | + |
| 329 | + matching_names = test_suites & top_level_classes.keys() |
| 330 | + for matched_name in matching_names: |
| 331 | + for def_name in all_defs: |
| 332 | + if ( |
| 333 | + def_name.type == "function" |
| 334 | + and def_name.full_name is not None |
| 335 | + and f".{matched_name}." in def_name.full_name |
| 336 | + ): |
| 337 | + for function in functions_to_search: |
| 338 | + (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) |
325 | 339 |
|
326 | | - if test_framework == "pytest": |
327 | | - for function in functions: |
328 | | - if "[" in function.test_function: |
329 | | - function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] |
330 | | - parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] |
331 | | - if function_name in top_level_functions: |
| 340 | + if is_parameterized and new_function == def_name.name: |
332 | 341 | test_functions.add( |
333 | | - TestFunction(function_name, function.test_class, parameters, function.test_type) |
| 342 | + TestFunction( |
| 343 | + function_name=def_name.name, |
| 344 | + test_class=matched_name, |
| 345 | + parameters=parameters, |
| 346 | + test_type=functions[0].test_type, |
| 347 | + ) |
334 | 348 | ) |
335 | | - elif function.test_function in top_level_functions: |
336 | | - test_functions.add( |
337 | | - TestFunction(function.test_function, function.test_class, None, function.test_type) |
338 | | - ) |
339 | | - elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): |
340 | | - base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) |
341 | | - if base_name in top_level_functions: |
| 349 | + elif function == def_name.name: |
342 | 350 | test_functions.add( |
343 | 351 | TestFunction( |
344 | | - function_name=base_name, |
345 | | - test_class=function.test_class, |
346 | | - parameters=function.test_function, |
347 | | - test_type=function.test_type, |
| 352 | + function_name=def_name.name, |
| 353 | + test_class=matched_name, |
| 354 | + parameters=None, |
| 355 | + test_type=functions[0].test_type, |
348 | 356 | ) |
349 | 357 | ) |
350 | 358 |
|
351 | | - elif test_framework == "unittest": |
352 | | - functions_to_search = [elem.test_function for elem in functions] |
353 | | - test_suites = {elem.test_class for elem in functions} |
354 | | - |
355 | | - matching_names = test_suites & top_level_classes.keys() |
356 | | - for matched_name in matching_names: |
357 | | - for def_name in all_defs: |
358 | | - if ( |
359 | | - def_name.type == "function" |
360 | | - and def_name.full_name is not None |
361 | | - and f".{matched_name}." in def_name.full_name |
362 | | - ): |
363 | | - for function in functions_to_search: |
364 | | - (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) |
365 | | - |
366 | | - if is_parameterized and new_function == def_name.name: |
367 | | - test_functions.add( |
368 | | - TestFunction( |
369 | | - function_name=def_name.name, |
370 | | - test_class=matched_name, |
371 | | - parameters=parameters, |
372 | | - test_type=functions[0].test_type, |
373 | | - ) |
374 | | - ) |
375 | | - elif function == def_name.name: |
376 | | - test_functions.add( |
377 | | - TestFunction( |
378 | | - function_name=def_name.name, |
379 | | - test_class=matched_name, |
380 | | - parameters=None, |
381 | | - test_type=functions[0].test_type, |
382 | | - ) |
383 | | - ) |
384 | | - |
385 | | - test_functions_list = list(test_functions) |
386 | | - test_functions_raw = [elem.function_name for elem in test_functions_list] |
387 | | - |
388 | | - test_functions_by_name = defaultdict(list) |
389 | | - for i, func_name in enumerate(test_functions_raw): |
390 | | - test_functions_by_name[func_name].append(i) |
391 | | - |
392 | | - for name in all_names: |
393 | | - if name.full_name is None: |
394 | | - continue |
395 | | - m = FUNCTION_NAME_REGEX.search(name.full_name) |
396 | | - if not m: |
397 | | - continue |
398 | | - |
399 | | - scope = m.group(1) |
400 | | - if scope not in test_functions_by_name: |
401 | | - continue |
402 | | - |
403 | | - cache_key = (name.full_name, name.module_name) |
404 | | - try: |
405 | | - if cache_key in goto_cache: |
406 | | - definition = goto_cache[cache_key] |
407 | | - else: |
408 | | - definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
409 | | - goto_cache[cache_key] = definition |
410 | | - except Exception as e: |
411 | | - logger.debug(str(e)) |
412 | | - continue |
413 | | - |
414 | | - if not definition or definition[0].type != "function": |
415 | | - continue |
416 | | - |
417 | | - definition_path = str(definition[0].module_path) |
418 | | - if ( |
419 | | - definition_path.startswith(str(project_root_path) + os.sep) |
420 | | - and definition[0].module_name != name.module_name |
421 | | - and definition[0].full_name is not None |
422 | | - ): |
423 | | - for index in test_functions_by_name[scope]: |
424 | | - scope_test_function = test_functions_list[index].function_name |
425 | | - scope_test_class = test_functions_list[index].test_class |
426 | | - scope_parameters = test_functions_list[index].parameters |
427 | | - test_type = test_functions_list[index].test_type |
428 | | - |
429 | | - if scope_parameters is not None: |
430 | | - if test_framework == "pytest": |
431 | | - scope_test_function += "[" + scope_parameters + "]" |
432 | | - if test_framework == "unittest": |
433 | | - scope_test_function += "_" + scope_parameters |
434 | | - |
435 | | - full_name_without_module_prefix = definition[0].full_name.replace( |
436 | | - definition[0].module_name + ".", "", 1 |
437 | | - ) |
438 | | - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" |
439 | | - |
440 | | - tests_cache.insert_test( |
441 | | - file_path=str(test_file), |
442 | | - file_hash=file_hash, |
443 | | - qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, |
444 | | - function_name=scope, |
445 | | - test_class=scope_test_class, |
446 | | - test_function=scope_test_function, |
447 | | - test_type=test_type, |
448 | | - line_number=name.line, |
449 | | - col_number=name.column, |
450 | | - ) |
| 359 | + test_functions_list = list(test_functions) |
| 360 | + test_functions_raw = [elem.function_name for elem in test_functions_list] |
451 | 361 |
|
452 | | - function_to_test_map[qualified_name_with_modules_from_root].add( |
453 | | - FunctionCalledInTest( |
454 | | - tests_in_file=TestsInFile( |
455 | | - test_file=test_file, |
456 | | - test_class=scope_test_class, |
457 | | - test_function=scope_test_function, |
458 | | - test_type=test_type, |
459 | | - ), |
460 | | - position=CodePosition(line_no=name.line, col_no=name.column), |
461 | | - ) |
462 | | - ) |
| 362 | + test_functions_by_name = defaultdict(list) |
| 363 | + for i, func_name in enumerate(test_functions_raw): |
| 364 | + test_functions_by_name[func_name].append(i) |
463 | 365 |
|
464 | | - progress.advance(task_id) |
| 366 | + for name in all_names: |
| 367 | + if name.full_name is None: |
| 368 | + continue |
| 369 | + m = FUNCTION_NAME_REGEX.search(name.full_name) |
| 370 | + if not m: |
| 371 | + continue |
| 372 | + |
| 373 | + scope = m.group(1) |
| 374 | + if scope not in test_functions_by_name: |
| 375 | + continue |
| 376 | + |
| 377 | + cache_key = (name.full_name, name.module_name) |
| 378 | + try: |
| 379 | + if cache_key in goto_cache: |
| 380 | + definition = goto_cache[cache_key] |
| 381 | + else: |
| 382 | + definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
| 383 | + goto_cache[cache_key] = definition |
| 384 | + except Exception as e: |
| 385 | + logger.debug(str(e)) |
| 386 | + continue |
| 387 | + |
| 388 | + if not definition or definition[0].type != "function": |
| 389 | + continue |
| 390 | + |
| 391 | + definition_path = str(definition[0].module_path) |
| 392 | + if ( |
| 393 | + definition_path.startswith(str(project_root_path) + os.sep) |
| 394 | + and definition[0].module_name != name.module_name |
| 395 | + and definition[0].full_name is not None |
| 396 | + ): |
| 397 | + for index in test_functions_by_name[scope]: |
| 398 | + scope_test_function = test_functions_list[index].function_name |
| 399 | + scope_test_class = test_functions_list[index].test_class |
| 400 | + scope_parameters = test_functions_list[index].parameters |
| 401 | + test_type = test_functions_list[index].test_type |
| 402 | + |
| 403 | + if scope_parameters is not None: |
| 404 | + if test_framework == "pytest": |
| 405 | + scope_test_function += "[" + scope_parameters + "]" |
| 406 | + if test_framework == "unittest": |
| 407 | + scope_test_function += "_" + scope_parameters |
| 408 | + |
| 409 | + full_name_without_module_prefix = definition[0].full_name.replace( |
| 410 | + definition[0].module_name + ".", "", 1 |
| 411 | + ) |
| 412 | + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" |
| 413 | + function_called_in_test = FunctionCalledInTest( |
| 414 | + tests_in_file=TestsInFile( |
| 415 | + test_file=test_file, |
| 416 | + test_class=scope_test_class, |
| 417 | + test_function=scope_test_function, |
| 418 | + test_type=test_type, |
| 419 | + ), |
| 420 | + position=CodePosition(line_no=name.line, col_no=name.column), |
| 421 | + ) |
| 422 | + results.append((qualified_name_with_modules_from_root, function_called_in_test)) |
| 423 | + |
| 424 | + return str(test_file), results |
| 425 | + |
| 426 | + |
| 427 | +def process_test_files( |
| 428 | + file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig |
| 429 | +) -> dict[str, list[FunctionCalledInTest]]: |
| 430 | + project_root_path = cfg.project_root_path |
| 431 | + test_framework = cfg.test_framework |
| 432 | + function_to_test_map = defaultdict(set) |
| 433 | + |
| 434 | + import multiprocessing |
| 435 | + |
| 436 | + max_workers = min(len(file_to_test_map), multiprocessing.cpu_count()) |
| 437 | + max_workers = max(1, max_workers) |
| 438 | + |
| 439 | + with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
| 440 | + progress, |
| 441 | + task_id, |
| 442 | + ): |
| 443 | + if len(file_to_test_map) == 1 or max_workers == 1: |
| 444 | + for test_file, functions in file_to_test_map.items(): |
| 445 | + _, results = _process_single_test_file(test_file, functions, project_root_path, test_framework) |
| 446 | + for qualified_name, function_called in results: |
| 447 | + function_to_test_map[qualified_name].add(function_called) |
| 448 | + progress.advance(task_id) |
| 449 | + else: |
| 450 | + with ProcessPoolExecutor(max_workers=max_workers) as executor: |
| 451 | + future_to_file = { |
| 452 | + executor.submit( |
| 453 | + _process_single_test_file, test_file, functions, project_root_path, test_framework |
| 454 | + ): test_file |
| 455 | + for test_file, functions in file_to_test_map.items() |
| 456 | + } |
| 457 | + |
| 458 | + for future in as_completed(future_to_file): |
| 459 | + try: |
| 460 | + _, results = future.result() |
| 461 | + for qualified_name, function_called in results: |
| 462 | + function_to_test_map[qualified_name].add(function_called) |
| 463 | + progress.advance(task_id) |
| 464 | + except Exception as e: # noqa: PERF203 |
| 465 | + test_file = future_to_file[future] |
| 466 | + logger.error(f"Error processing test file {test_file}: {e}") |
| 467 | + progress.advance(task_id) |
465 | 468 |
|
466 | | - tests_cache.close() |
467 | 469 | return {function: list(tests) for function, tests in function_to_test_map.items()} |
0 commit comments