diff --git a/aws_doc_sdk_examples_tools/fs.py b/aws_doc_sdk_examples_tools/fs.py index 803563c..e5eec01 100644 --- a/aws_doc_sdk_examples_tools/fs.py +++ b/aws_doc_sdk_examples_tools/fs.py @@ -28,7 +28,7 @@ def read(self, path: Path) -> str: pass @abstractmethod - def readlines(self, path: Path) -> List[str]: + def readlines(self, path: Path, encoding: str = "utf-8") -> List[str]: pass @abstractmethod @@ -56,12 +56,12 @@ def read(self, path: Path) -> str: with path.open("r", encoding="utf-8") as file: return file.read() - def readlines(self, path: Path) -> List[str]: - with path.open("r") as file: + def readlines(self, path: Path, encoding: str = "utf-8") -> List[str]: + with path.open("r", encoding=encoding) as file: return file.readlines() def write(self, path: Path, content: str): - with path.open("w") as file: + with path.open("w", encoding="utf-8") as file: file.write(content) def stat(self, path: Path) -> Stat: @@ -95,7 +95,7 @@ def glob(self, path: Path, glob: str) -> Generator[Path, None, None]: def read(self, path: Path) -> str: return self.fs[path] - def readlines(self, path: Path) -> List[str]: + def readlines(self, path: Path, encoding: str = "utf-8") -> List[str]: content = self.fs[path] return content.splitlines(keepends=True) diff --git a/aws_doc_sdk_examples_tools/snippets.py b/aws_doc_sdk_examples_tools/snippets.py index e4c24d6..85ea070 100644 --- a/aws_doc_sdk_examples_tools/snippets.py +++ b/aws_doc_sdk_examples_tools/snippets.py @@ -8,6 +8,7 @@ from .validator_config import skip from .file_utils import get_files, clear +from .fs import Fs, PathFs from .metadata import Example from .metadata_errors import MetadataErrors, MetadataError from .project_validator import ( @@ -145,16 +146,17 @@ def parse_snippets( return snippets, errors -def find_snippets(file: Path, prefix: str) -> Tuple[Dict[str, Snippet], MetadataErrors]: +def find_snippets( + file: Path, prefix: str, fs: Fs = PathFs() +) -> Tuple[Dict[str, Snippet], MetadataErrors]: errors = 
MetadataErrors() snippets: Dict[str, Snippet] = {} try: - with open(file, encoding="utf-8") as snippet_file: - try: - snippets, errs = parse_snippets(snippet_file.readlines(), file, prefix) - errors.extend(errs) - except UnicodeDecodeError as err: - errors.append(MetadataUnicodeError(file=file, err=err)) + lines = fs.readlines(file) + snippets, errs = parse_snippets(lines, file, prefix) + errors.extend(errs) + except UnicodeDecodeError as err: + errors.append(MetadataUnicodeError(file=file, err=err)) except FileNotFoundError: pass except Exception as e: @@ -163,12 +165,12 @@ def find_snippets(file: Path, prefix: str) -> Tuple[Dict[str, Snippet], Metadata def collect_snippets( - root: Path, prefix: str = "" + root: Path, prefix: str = "", fs: Fs = PathFs() ) -> Tuple[Dict[str, Snippet], MetadataErrors]: snippets: Dict[str, Snippet] = {} errors = MetadataErrors() - for file in get_files(root, skip): - snips, errs = find_snippets(file, prefix) + for file in get_files(root, skip, fs=fs): + snips, errs = find_snippets(file, prefix, fs=fs) snippets.update(snips) errors.extend(errs) return snippets, errors @@ -180,6 +182,7 @@ def collect_snippet_files( prefix: str, errors: MetadataErrors, root: Path, + fs: Fs = PathFs(), ): for example in examples: for lang in example.languages: @@ -187,7 +190,9 @@ def collect_snippet_files( for version in language.versions: for excerpt in version.excerpts: for snippet_file in excerpt.snippet_files: - if not (root / snippet_file).exists(): + snippet_path = root / snippet_file + snippet_stat = fs.stat(snippet_path) + if not snippet_stat.exists: # Ensure all snippet_files exist errors.append( MissingSnippetFile( @@ -207,17 +212,14 @@ def collect_snippet_files( ) continue name = prefix + str(snippet_file).replace("/", ".") - with open(root / snippet_file, encoding="utf-8") as file: - code = file.readlines() - snippets[name] = Snippet( - id=name, - file=snippet_file, - line_start=0, - line_end=len(code), - code="".join( - 
strip_snippet_tags(strip_spdx_header(code)) - ), - ) + code = fs.readlines(snippet_path) + snippets[name] = Snippet( + id=name, + file=snippet_file, + line_start=0, + line_end=len(code), + code="".join(strip_snippet_tags(strip_spdx_header(code))), + ) def strip_snippet_tags(lines: List[str]) -> List[str]: diff --git a/aws_doc_sdk_examples_tools/snippets_test.py b/aws_doc_sdk_examples_tools/snippets_test.py index d18d579..23d0b44 100644 --- a/aws_doc_sdk_examples_tools/snippets_test.py +++ b/aws_doc_sdk_examples_tools/snippets_test.py @@ -5,6 +5,9 @@ from pathlib import Path from aws_doc_sdk_examples_tools import snippets +from aws_doc_sdk_examples_tools.fs import RecordFs +from aws_doc_sdk_examples_tools.metadata import Example, Language, Version, Excerpt +from aws_doc_sdk_examples_tools.metadata_errors import MetadataErrors @pytest.mark.parametrize( @@ -102,3 +105,159 @@ def test_strip_spdx_header(): ) assert [] == snippets.strip_spdx_header([]) + + +class TestFindSnippetsFs: + """Test find_snippets with filesystem abstraction.""" + + def test_find_snippets_with_recordfs(self): + """Test find_snippets using RecordFs.""" + fs = RecordFs( + { + Path( + "/project/test.py" + ): """# snippet-start:[example.hello] +def hello(): + print("Hello, World!") +# snippet-end:[example.hello] +""" + } + ) + + snippet_dict, errors = snippets.find_snippets( + Path("/project/test.py"), "", fs=fs + ) + + assert len(errors) == 0 + assert len(snippet_dict) == 1 + assert "example.hello" in snippet_dict + snippet = snippet_dict["example.hello"] + assert snippet.id == "example.hello" + assert "def hello():" in snippet.code + + def test_find_snippets_missing_file_graceful(self): + """Test find_snippets behavior with missing files.""" + fs = RecordFs({}) + + snippet_dict, errors = snippets.find_snippets( + Path("/project/missing.py"), "", fs=fs + ) + + # Missing files generate errors (not handled gracefully) + assert len(snippet_dict) == 0 + assert len(errors) == 1 # Should have a 
FileReadError + + +class TestCollectSnippetsFs: + """Test collect_snippets with filesystem abstraction.""" + + def test_collect_snippets_with_recordfs(self): + """Test collect_snippets using RecordFs.""" + fs = RecordFs( + { + Path( + "/project/src/file1.py" + ): """# snippet-start:[example1] +def example1(): + pass +# snippet-end:[example1] +""", + Path( + "/project/src/file2.py" + ): """# snippet-start:[example2] +def example2(): + pass +# snippet-end:[example2] +""", + } + ) + + snippet_dict, errors = snippets.collect_snippets(Path("/project/src"), fs=fs) + + assert len(errors) == 0 + assert len(snippet_dict) == 2 + assert "example1" in snippet_dict + assert "example2" in snippet_dict + + +class TestCollectSnippetFilesFs: + """Test collect_snippet_files with filesystem abstraction.""" + + def test_collect_snippet_files_with_recordfs(self): + """Test collect_snippet_files using RecordFs.""" + fs = RecordFs({Path("/project/example.py"): "print('Hello, World!')\n"}) + + example = Example( + id="test_example", + file=None, + languages={ + "python": Language( + name="python", + property="python", + versions=[ + Version( + sdk_version="3", + excerpts=[ + Excerpt( + description="Test excerpt", + snippet_tags=[], + snippet_files=["example.py"], + ) + ], + ) + ], + ) + }, + ) + + snippet_dict = {} + errors = MetadataErrors() + + snippets.collect_snippet_files( + [example], snippet_dict, "", errors, Path("/project"), fs=fs + ) + + assert len(errors) == 0 + assert len(snippet_dict) == 1 + assert "example.py" in snippet_dict + snippet = snippet_dict["example.py"] + assert snippet.file == "example.py" + assert "Hello, World!" 
in snippet.code + + def test_collect_snippet_files_missing_file_error(self): + """Test collect_snippet_files properly reports missing files as errors.""" + fs = RecordFs({}) # Empty filesystem + + example = Example( + id="test_example", + file=None, + languages={ + "python": Language( + name="python", + property="python", + versions=[ + Version( + sdk_version="3", + excerpts=[ + Excerpt( + description="Test excerpt", + snippet_tags=[], + snippet_files=["missing.py"], + ) + ], + ) + ], + ) + }, + ) + + snippet_dict = {} + errors = MetadataErrors() + + snippets.collect_snippet_files( + [example], snippet_dict, "", errors, Path("/project"), fs=fs + ) + + # Missing snippet files should generate errors (unlike find_snippets on PathFs, which silently skips FileNotFoundError) + assert len(errors) == 1 + assert len(snippet_dict) == 0