Skip to content

Commit 3f23f8b

Browse files
Merge pull request #122 from r1chardj0n3s/pyfiles-return-path
Make pyfiles generate paths rather than strings
2 parents 5486d2a + a9c39e9 commit 3f23f8b

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

pip_check_reqs/common.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ def finalise(self) -> Dict[str, FoundModule]:
124124
return result
125125

126126

127-
def pyfiles(root: Path) -> Generator[str, None, None]:
127+
def pyfiles(root: Path) -> Generator[Path, None, None]:
128128
if root.is_file():
129129
if root.suffix == ".py":
130-
yield str(root.absolute())
130+
yield root.absolute()
131131
else:
132132
raise ValueError(f"{root} is not a python file or directory")
133133
elif root.is_dir():
134134
for item in root.rglob("*.py"):
135-
yield str(item.absolute())
135+
yield item.absolute()
136136

137137

138138
def find_imported_modules(
@@ -143,14 +143,14 @@ def find_imported_modules(
143143
vis = _ImportVisitor(ignore_modules_function=ignore_modules_function)
144144
for path in paths:
145145
for filename in pyfiles(path):
146-
if ignore_files_function(filename):
146+
if ignore_files_function(str(filename)):
147147
log.info("ignoring: %s", os.path.relpath(filename))
148148
continue
149149
log.debug("scanning: %s", os.path.relpath(filename))
150150
with open(filename, encoding="utf-8") as file_obj:
151151
content = file_obj.read()
152-
vis.set_location(filename)
153-
vis.visit(ast.parse(content, filename))
152+
vis.set_location(str(filename))
153+
vis.visit(ast.parse(content, str(filename)))
154154
return vis.finalise()
155155

156156

tests/test_common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_import_visitor(stmt: str, result: List[str]) -> None:
6666
def test_pyfiles_file(tmp_path: Path) -> None:
6767
python_file = tmp_path / "example.py"
6868
python_file.touch()
69-
assert list(common.pyfiles(root=python_file)) == [str(python_file)]
69+
assert list(common.pyfiles(root=python_file)) == [python_file]
7070

7171

7272
def test_pyfiles_file_no_dice(tmp_path: Path) -> None:
@@ -89,8 +89,8 @@ def test_pyfiles_package(tmp_path: Path) -> None:
8989
not_python_file.touch()
9090

9191
assert list(common.pyfiles(root=tmp_path)) == [
92-
str(python_file),
93-
str(nested_python_file),
92+
python_file,
93+
nested_python_file,
9494
]
9595

9696

0 commit comments

Comments
 (0)