Skip to content

Commit a826252

Browse files
committed
Update test_code_context_extractor.py
1 parent 3c4573e commit a826252

File tree

1 file changed

+30
-31
lines changed

1 file changed

+30
-31
lines changed

tests/test_code_context_extractor.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def sort_from_another_file(arr):
266266
assert hashing_context.strip() == expected_hashing_context.strip()
267267

268268

269-
def test_flavio_typed_code_helper() -> None:
269+
def test_flavio_typed_code_helper(tmp_path: Path) -> None:
270270
code = '''
271271
272272
_P = ParamSpec("_P")
@@ -432,33 +432,32 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
432432
lifespan=self.__duration__,
433433
)
434434
'''
435-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
436-
f.write(code)
437-
f.flush()
438-
file_path = Path(f.name).resolve()
439-
opt = Optimizer(
440-
Namespace(
441-
project_root=file_path.parent.resolve(),
442-
disable_telemetry=True,
443-
tests_root="tests",
444-
test_framework="pytest",
445-
pytest_cmd="pytest",
446-
experiment_id=None,
447-
test_project_root=Path().resolve(),
448-
)
449-
)
450-
function_to_optimize = FunctionToOptimize(
451-
function_name="__call__",
452-
file_path=file_path,
453-
parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
454-
starting_line=None,
455-
ending_line=None,
435+
# Create a temporary Python file using pytest's tmp_path fixture
436+
file_path = tmp_path / "test_code.py"
437+
file_path.write_text(code, encoding="utf-8")
438+
opt = Optimizer(
439+
Namespace(
440+
project_root=file_path.parent.resolve(),
441+
disable_telemetry=True,
442+
tests_root="tests",
443+
test_framework="pytest",
444+
pytest_cmd="pytest",
445+
experiment_id=None,
446+
test_project_root=Path().resolve(),
456447
)
448+
)
449+
function_to_optimize = FunctionToOptimize(
450+
function_name="__call__",
451+
file_path=file_path,
452+
parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
453+
starting_line=None,
454+
ending_line=None,
455+
)
457456

458-
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
459-
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
460-
hashing_context = code_ctx.hashing_code_context
461-
expected_read_write_context = f"""
457+
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
458+
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
459+
hashing_context = code_ctx.hashing_code_context
460+
expected_read_write_context = f"""
462461
```python:{file_path.relative_to(opt.args.project_root)}
463462
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
464463
@@ -556,7 +555,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
556555
)
557556
```
558557
"""
559-
expected_read_only_context = f'''
558+
expected_read_only_context = f'''
560559
```python:{file_path.relative_to(opt.args.project_root)}
561560
_P = ParamSpec("_P")
562561
_KEY_T = TypeVar("_KEY_T")
@@ -612,7 +611,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
612611
__backend__: _CacheBackendT
613612
```
614613
'''
615-
expected_hashing_context = f"""
614+
expected_hashing_context = f"""
616615
```python:{file_path.relative_to(opt.args.project_root)}
617616
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
618617
@@ -649,9 +648,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
649648
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
650649
```
651650
"""
652-
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
653-
assert read_only_context.strip() == expected_read_only_context.strip()
654-
assert hashing_context.strip() == expected_hashing_context.strip()
651+
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
652+
assert read_only_context.strip() == expected_read_only_context.strip()
653+
assert hashing_context.strip() == expected_hashing_context.strip()
655654

656655

657656
def test_example_class() -> None:

0 commit comments

Comments
 (0)