|
12 | 12 | from pathlib import Path |
13 | 13 | from typing import TYPE_CHECKING, Callable, Optional |
14 | 14 |
|
| 15 | +import jedi |
15 | 16 | import pytest |
16 | 17 | from pydantic.dataclasses import dataclass |
17 | 18 | from rich.panel import Panel |
@@ -288,191 +289,195 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N |
288 | 289 | return False, function_name, None |
289 | 290 |
|
290 | 291 |
|
291 | | -def process_test_files( |
292 | | - file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig |
293 | | -) -> dict[str, list[FunctionCalledInTest]]: |
294 | | - import jedi |
295 | | - |
| 292 | +def _process_single_test_file( |
| 293 | + test_file: Path, |
| 294 | + functions: list[TestsInFile], |
| 295 | + cfg: TestConfig, |
| 296 | + jedi_project: jedi.Project, |
| 297 | + goto_cache: dict, |
| 298 | + tests_cache: TestsCache, |
| 299 | + function_to_test_map: defaultdict, |
| 300 | +) -> None: |
296 | 301 | project_root_path = cfg.project_root_path |
297 | 302 | test_framework = cfg.test_framework |
| 303 | + file_hash = TestsCache.compute_file_hash(test_file) |
| 304 | + cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) |
| 305 | + if cached_tests: |
| 306 | + self_cur = tests_cache.cur |
| 307 | + self_cur.execute( |
| 308 | + "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?", |
| 309 | + (str(test_file), file_hash), |
| 310 | + ) |
| 311 | + qualified_names = [row[0] for row in self_cur.fetchall()] |
| 312 | + for cached, qualified_name in zip(cached_tests, qualified_names): |
| 313 | + function_to_test_map[qualified_name].add(cached) |
| 314 | + return |
298 | 315 |
|
299 | | - function_to_test_map = defaultdict(set) |
300 | | - jedi_project = jedi.Project(path=project_root_path) |
301 | | - goto_cache = {} |
302 | | - tests_cache = TestsCache() |
| 316 | + try: |
| 317 | + script = jedi.Script(path=test_file, project=jedi_project) |
| 318 | + test_functions = set() |
303 | 319 |
|
304 | | - with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
305 | | - progress, |
306 | | - task_id, |
307 | | - ): |
308 | | - for test_file, functions in file_to_test_map.items(): |
309 | | - file_hash = TestsCache.compute_file_hash(test_file) |
310 | | - cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) |
311 | | - if cached_tests: |
312 | | - self_cur = tests_cache.cur |
313 | | - self_cur.execute( |
314 | | - "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?", |
315 | | - (str(test_file), file_hash), |
316 | | - ) |
317 | | - qualified_names = [row[0] for row in self_cur.fetchall()] |
318 | | - for cached, qualified_name in zip(cached_tests, qualified_names): |
319 | | - function_to_test_map[qualified_name].add(cached) |
320 | | - progress.advance(task_id) |
321 | | - continue |
| 320 | + all_names = script.get_names(all_scopes=True, references=True) |
| 321 | + all_defs = script.get_names(all_scopes=True, definitions=True) |
| 322 | + all_names_top = script.get_names(all_scopes=True) |
322 | 323 |
|
323 | | - try: |
324 | | - script = jedi.Script(path=test_file, project=jedi_project) |
325 | | - test_functions = set() |
| 324 | + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
| 325 | + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
| 326 | + except Exception as e: |
| 327 | + logger.debug(f"Failed to get jedi script for {test_file}: {e}") |
| 328 | + return |
| 329 | + |
| 330 | + if test_framework == "pytest": |
| 331 | + for function in functions: |
| 332 | + if "[" in function.test_function: |
| 333 | + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] |
| 334 | + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] |
| 335 | + if function_name in top_level_functions: |
| 336 | + test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type)) |
| 337 | + elif function.test_function in top_level_functions: |
| 338 | + test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type)) |
| 339 | + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): |
| 340 | + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) |
| 341 | + if base_name in top_level_functions: |
| 342 | + test_functions.add( |
| 343 | + TestFunction( |
| 344 | + function_name=base_name, |
| 345 | + test_class=function.test_class, |
| 346 | + parameters=function.test_function, |
| 347 | + test_type=function.test_type, |
| 348 | + ) |
| 349 | + ) |
326 | 350 |
|
327 | | - all_names = script.get_names(all_scopes=True, references=True) |
328 | | - all_defs = script.get_names(all_scopes=True, definitions=True) |
329 | | - all_names_top = script.get_names(all_scopes=True) |
| 351 | + elif test_framework == "unittest": |
| 352 | + functions_to_search = [elem.test_function for elem in functions] |
| 353 | + test_suites = {elem.test_class for elem in functions} |
330 | 354 |
|
331 | | - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
332 | | - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
333 | | - except Exception as e: |
334 | | - logger.debug(f"Failed to get jedi script for {test_file}: {e}") |
335 | | - progress.advance(task_id) |
336 | | - continue |
| 355 | + matching_names = test_suites & top_level_classes.keys() |
| 356 | + for matched_name in matching_names: |
| 357 | + for def_name in all_defs: |
| 358 | + if ( |
| 359 | + def_name.type == "function" |
| 360 | + and def_name.full_name is not None |
| 361 | + and f".{matched_name}." in def_name.full_name |
| 362 | + ): |
| 363 | + for function in functions_to_search: |
| 364 | + (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) |
337 | 365 |
|
338 | | - if test_framework == "pytest": |
339 | | - for function in functions: |
340 | | - if "[" in function.test_function: |
341 | | - function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] |
342 | | - parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] |
343 | | - if function_name in top_level_functions: |
| 366 | + if is_parameterized and new_function == def_name.name: |
344 | 367 | test_functions.add( |
345 | | - TestFunction(function_name, function.test_class, parameters, function.test_type) |
| 368 | + TestFunction( |
| 369 | + function_name=def_name.name, |
| 370 | + test_class=matched_name, |
| 371 | + parameters=parameters, |
| 372 | + test_type=functions[0].test_type, |
| 373 | + ) |
346 | 374 | ) |
347 | | - elif function.test_function in top_level_functions: |
348 | | - test_functions.add( |
349 | | - TestFunction(function.test_function, function.test_class, None, function.test_type) |
350 | | - ) |
351 | | - elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): |
352 | | - base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) |
353 | | - if base_name in top_level_functions: |
| 375 | + elif function == def_name.name: |
354 | 376 | test_functions.add( |
355 | 377 | TestFunction( |
356 | | - function_name=base_name, |
357 | | - test_class=function.test_class, |
358 | | - parameters=function.test_function, |
359 | | - test_type=function.test_type, |
| 378 | + function_name=def_name.name, |
| 379 | + test_class=matched_name, |
| 380 | + parameters=None, |
| 381 | + test_type=functions[0].test_type, |
360 | 382 | ) |
361 | 383 | ) |
362 | 384 |
|
363 | | - elif test_framework == "unittest": |
364 | | - functions_to_search = [elem.test_function for elem in functions] |
365 | | - test_suites = {elem.test_class for elem in functions} |
366 | | - |
367 | | - matching_names = test_suites & top_level_classes.keys() |
368 | | - for matched_name in matching_names: |
369 | | - for def_name in all_defs: |
370 | | - if ( |
371 | | - def_name.type == "function" |
372 | | - and def_name.full_name is not None |
373 | | - and f".{matched_name}." in def_name.full_name |
374 | | - ): |
375 | | - for function in functions_to_search: |
376 | | - (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) |
377 | | - |
378 | | - if is_parameterized and new_function == def_name.name: |
379 | | - test_functions.add( |
380 | | - TestFunction( |
381 | | - function_name=def_name.name, |
382 | | - test_class=matched_name, |
383 | | - parameters=parameters, |
384 | | - test_type=functions[0].test_type, |
385 | | - ) |
386 | | - ) |
387 | | - elif function == def_name.name: |
388 | | - test_functions.add( |
389 | | - TestFunction( |
390 | | - function_name=def_name.name, |
391 | | - test_class=matched_name, |
392 | | - parameters=None, |
393 | | - test_type=functions[0].test_type, |
394 | | - ) |
395 | | - ) |
396 | | - |
397 | | - test_functions_list = list(test_functions) |
398 | | - test_functions_raw = [elem.function_name for elem in test_functions_list] |
399 | | - |
400 | | - test_functions_by_name = defaultdict(list) |
401 | | - for i, func_name in enumerate(test_functions_raw): |
402 | | - test_functions_by_name[func_name].append(i) |
403 | | - |
404 | | - for name in all_names: |
405 | | - if name.full_name is None: |
406 | | - continue |
407 | | - m = FUNCTION_NAME_REGEX.search(name.full_name) |
408 | | - if not m: |
409 | | - continue |
410 | | - |
411 | | - scope = m.group(1) |
412 | | - if scope not in test_functions_by_name: |
413 | | - continue |
414 | | - |
415 | | - cache_key = (name.full_name, name.module_name) |
416 | | - try: |
417 | | - if cache_key in goto_cache: |
418 | | - definition = goto_cache[cache_key] |
419 | | - else: |
420 | | - definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
421 | | - goto_cache[cache_key] = definition |
422 | | - except Exception as e: |
423 | | - logger.debug(str(e)) |
424 | | - continue |
425 | | - |
426 | | - if not definition or definition[0].type != "function": |
427 | | - continue |
428 | | - |
429 | | - definition_path = str(definition[0].module_path) |
430 | | - if ( |
431 | | - definition_path.startswith(str(project_root_path) + os.sep) |
432 | | - and definition[0].module_name != name.module_name |
433 | | - and definition[0].full_name is not None |
434 | | - ): |
435 | | - for index in test_functions_by_name[scope]: |
436 | | - scope_test_function = test_functions_list[index].function_name |
437 | | - scope_test_class = test_functions_list[index].test_class |
438 | | - scope_parameters = test_functions_list[index].parameters |
439 | | - test_type = test_functions_list[index].test_type |
440 | | - |
441 | | - if scope_parameters is not None: |
442 | | - if test_framework == "pytest": |
443 | | - scope_test_function += "[" + scope_parameters + "]" |
444 | | - if test_framework == "unittest": |
445 | | - scope_test_function += "_" + scope_parameters |
446 | | - |
447 | | - full_name_without_module_prefix = definition[0].full_name.replace( |
448 | | - definition[0].module_name + ".", "", 1 |
449 | | - ) |
450 | | - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" |
| 385 | + test_functions_list = list(test_functions) |
| 386 | + test_functions_raw = [elem.function_name for elem in test_functions_list] |
| 387 | + |
| 388 | + test_functions_by_name = defaultdict(list) |
| 389 | + for i, func_name in enumerate(test_functions_raw): |
| 390 | + test_functions_by_name[func_name].append(i) |
| 391 | + |
| 392 | + for name in all_names: |
| 393 | + if name.full_name is None: |
| 394 | + continue |
| 395 | + m = FUNCTION_NAME_REGEX.search(name.full_name) |
| 396 | + if not m: |
| 397 | + continue |
| 398 | + |
| 399 | + scope = m.group(1) |
| 400 | + if scope not in test_functions_by_name: |
| 401 | + continue |
| 402 | + |
| 403 | + cache_key = (name.full_name, name.module_name) |
| 404 | + try: |
| 405 | + if cache_key in goto_cache: |
| 406 | + definition = goto_cache[cache_key] |
| 407 | + else: |
| 408 | + definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
| 409 | + goto_cache[cache_key] = definition |
| 410 | + except Exception as e: |
| 411 | + logger.debug(str(e)) |
| 412 | + continue |
| 413 | + |
| 414 | + if not definition or definition[0].type != "function": |
| 415 | + continue |
451 | 416 |
|
452 | | - tests_cache.insert_test( |
453 | | - file_path=str(test_file), |
454 | | - file_hash=file_hash, |
455 | | - qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, |
456 | | - function_name=scope, |
| 417 | + definition_path = str(definition[0].module_path) |
| 418 | + if ( |
| 419 | + definition_path.startswith(str(project_root_path) + os.sep) |
| 420 | + and definition[0].module_name != name.module_name |
| 421 | + and definition[0].full_name is not None |
| 422 | + ): |
| 423 | + for index in test_functions_by_name[scope]: |
| 424 | + scope_test_function = test_functions_list[index].function_name |
| 425 | + scope_test_class = test_functions_list[index].test_class |
| 426 | + scope_parameters = test_functions_list[index].parameters |
| 427 | + test_type = test_functions_list[index].test_type |
| 428 | + |
| 429 | + if scope_parameters is not None: |
| 430 | + if test_framework == "pytest": |
| 431 | + scope_test_function += "[" + scope_parameters + "]" |
| 432 | + if test_framework == "unittest": |
| 433 | + scope_test_function += "_" + scope_parameters |
| 434 | + |
| 435 | + full_name_without_module_prefix = definition[0].full_name.replace( |
| 436 | + definition[0].module_name + ".", "", 1 |
| 437 | + ) |
| 438 | + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" |
| 439 | + |
| 440 | + tests_cache.insert_test( |
| 441 | + file_path=str(test_file), |
| 442 | + file_hash=file_hash, |
| 443 | + qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, |
| 444 | + function_name=scope, |
| 445 | + test_class=scope_test_class, |
| 446 | + test_function=scope_test_function, |
| 447 | + test_type=test_type, |
| 448 | + line_number=name.line, |
| 449 | + col_number=name.column, |
| 450 | + ) |
| 451 | + |
| 452 | + function_to_test_map[qualified_name_with_modules_from_root].add( |
| 453 | + FunctionCalledInTest( |
| 454 | + tests_in_file=TestsInFile( |
| 455 | + test_file=test_file, |
457 | 456 | test_class=scope_test_class, |
458 | 457 | test_function=scope_test_function, |
459 | 458 | test_type=test_type, |
460 | | - line_number=name.line, |
461 | | - col_number=name.column, |
462 | | - ) |
| 459 | + ), |
| 460 | + position=CodePosition(line_no=name.line, col_no=name.column), |
| 461 | + ) |
| 462 | + ) |
463 | 463 |
|
464 | | - function_to_test_map[qualified_name_with_modules_from_root].add( |
465 | | - FunctionCalledInTest( |
466 | | - tests_in_file=TestsInFile( |
467 | | - test_file=test_file, |
468 | | - test_class=scope_test_class, |
469 | | - test_function=scope_test_function, |
470 | | - test_type=test_type, |
471 | | - ), |
472 | | - position=CodePosition(line_no=name.line, col_no=name.column), |
473 | | - ) |
474 | | - ) |
475 | 464 |
|
| 465 | +def process_test_files( |
| 466 | + file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig |
| 467 | +) -> dict[str, list[FunctionCalledInTest]]: |
| 468 | + function_to_test_map = defaultdict(set) |
| 469 | + jedi_project = jedi.Project(path=cfg.project_root_path) |
| 470 | + goto_cache = {} |
| 471 | + tests_cache = TestsCache() |
| 472 | + |
| 473 | + with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
| 474 | + progress, |
| 475 | + task_id, |
| 476 | + ): |
| 477 | + for test_file, functions in file_to_test_map.items(): |
| 478 | + _process_single_test_file( |
| 479 | + test_file, functions, cfg, jedi_project, goto_cache, tests_cache, function_to_test_map |
| 480 | + ) |
476 | 481 | progress.advance(task_id) |
477 | 482 |
|
478 | 483 | tests_cache.close() |
|
0 commit comments