Skip to content

Commit e77c608

Browse files
committed
Make pyfiles take a pathlib.Path rather than a str
1 parent 3eeae27 commit e77c608

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

pip_check_reqs/common.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,14 @@ def finalise(self) -> Dict[str, FoundModule]:
124124
return result
125125

126126

127-
def pyfiles(root: str) -> Generator[str, None, None]:
128-
root_path = Path(root)
129-
if root_path.is_file():
130-
if root_path.suffix == ".py":
131-
yield str(root_path.absolute())
127+
def pyfiles(root: Path) -> Generator[str, None, None]:
128+
if root.is_file():
129+
if root.suffix == ".py":
130+
yield str(root.absolute())
132131
else:
133-
raise ValueError(f"{root_path} is not a python file or directory")
134-
elif root_path.is_dir():
135-
for item in root_path.rglob("*.py"):
132+
raise ValueError(f"{root} is not a python file or directory")
133+
elif root.is_dir():
134+
for item in root.rglob("*.py"):
136135
yield str(item.absolute())
137136

138137

@@ -143,7 +142,7 @@ def find_imported_modules(
143142
) -> Dict[str, FoundModule]:
144143
vis = _ImportVisitor(ignore_modules_function=ignore_modules_function)
145144
for path in paths:
146-
for filename in pyfiles(path):
145+
for filename in pyfiles(Path(path)):
147146
if ignore_files_function(filename):
148147
log.info("ignoring: %s", os.path.relpath(filename))
149148
continue

tests/test_common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ def test_import_visitor(stmt: str, result: List[str]) -> None:
6666
def test_pyfiles_file(tmp_path: Path) -> None:
6767
python_file = tmp_path / "example.py"
6868
python_file.touch()
69-
assert list(common.pyfiles(root=str(python_file))) == [str(python_file)]
69+
assert list(common.pyfiles(root=python_file)) == [str(python_file)]
7070

7171

7272
def test_pyfiles_file_no_dice(tmp_path: Path) -> None:
7373
not_python_file = tmp_path / "example"
7474
not_python_file.touch()
7575

7676
with pytest.raises(ValueError):
77-
list(common.pyfiles(root=str(not_python_file)))
77+
list(common.pyfiles(root=not_python_file))
7878

7979

8080
def test_pyfiles_package(tmp_path: Path) -> None:
@@ -88,7 +88,7 @@ def test_pyfiles_package(tmp_path: Path) -> None:
8888

8989
not_python_file.touch()
9090

91-
assert list(common.pyfiles(root=str(tmp_path))) == [
91+
assert list(common.pyfiles(root=tmp_path)) == [
9292
str(python_file),
9393
str(nested_python_file),
9494
]

0 commit comments

Comments
 (0)