diff --git a/aws_doc_sdk_examples_tools/doc_gen.py b/aws_doc_sdk_examples_tools/doc_gen.py index f819729..ef544de 100644 --- a/aws_doc_sdk_examples_tools/doc_gen.py +++ b/aws_doc_sdk_examples_tools/doc_gen.py @@ -16,6 +16,7 @@ # from os import glob from .categories import Category, parse as parse_categories +from .fs import Fs, PathFs from .metadata import ( Example, DocFilenames, @@ -55,6 +56,7 @@ class DocGenMergeWarning(MetadataError): class DocGen: root: Path errors: MetadataErrors + fs: Fs = field(default_factory=PathFs) entities: Dict[str, str] = field(default_factory=dict) prefix: Optional[str] = None validation: ValidationConfig = field(default_factory=ValidationConfig) @@ -171,8 +173,12 @@ def extend_examples(self, examples: Iterable[Example], errors: MetadataErrors): self.examples[id] = example @classmethod - def empty(cls, validation: ValidationConfig = ValidationConfig()) -> "DocGen": - return DocGen(root=Path("/"), errors=MetadataErrors(), validation=validation) + def empty( + cls, validation: ValidationConfig = ValidationConfig(), fs: Fs = PathFs() + ) -> "DocGen": + return DocGen( + root=Path("/"), errors=MetadataErrors(), validation=validation, fs=fs + ) @classmethod def default(cls) -> "DocGen": @@ -190,6 +196,7 @@ def clone(self) -> "DocGen": snippet_files=set(), cross_blocks=set(), examples={}, + fs=self.fs, ) def for_root( @@ -199,7 +206,7 @@ def for_root( config = config or Path(__file__).parent / "config" - doc_gen = DocGen.empty() + doc_gen = DocGen.empty(fs=self.fs) parse_config(doc_gen, root, config, self.validation.strict_titles) self.merge(doc_gen) @@ -209,31 +216,31 @@ def for_root( return self def find_and_process_metadata(self, metadata_path: Path): - for path in metadata_path.glob("*_metadata.yaml"): + for path in self.fs.glob(metadata_path, "*_metadata.yaml"): self.process_metadata(path) def process_metadata(self, path: Path) -> "DocGen": if path in self._loaded: return self try: - with open(path) as file: - examples, errs = parse_examples( - path, - yaml.safe_load(file), - self.sdks, - self.services, - self.standard_categories, - self.cross_blocks, - self.validation, - ) - self.extend_examples(examples, self.errors) - self.errors.extend(errs) - for example in examples: - for lang in example.languages: - language = example.languages[lang] - for version in language.versions: - for excerpt in version.excerpts: - self.snippet_files.update(excerpt.snippet_files) + content = self.fs.read(path) + examples, errs = parse_examples( + path, + yaml.safe_load(content), + self.sdks, + self.services, + self.standard_categories, + self.cross_blocks, + self.validation, + ) + self.extend_examples(examples, self.errors) + self.errors.extend(errs) + for example in examples: + for lang in example.languages: + language = example.languages[lang] + for version in language.versions: + for excerpt in version.excerpts: + self.snippet_files.update(excerpt.snippet_files) self._loaded.add(path) except ParserError as e: self.errors.append(YamlParseError(file=path, parser_error=str(e))) @@ -246,8 +253,9 @@ def from_root( config: Optional[Path] = None, validation: ValidationConfig = ValidationConfig(), incremental: bool = False, + fs: Fs = PathFs(), ) -> "DocGen": - return DocGen.empty(validation=validation).for_root( + return DocGen.empty(validation=validation, fs=fs).for_root( root, config, incremental=incremental ) @@ -348,6 +356,10 @@ def default(self, obj): "__entity_errors__": [{error.entity: error.message()} for error in obj] } + if isinstance(obj, Fs): + # Don't serialize filesystem objects for security + return {} + if isinstance(obj, set): return {"__set__": list(obj)} @@ -356,55 +368,53 @@ def default(self, obj): def parse_config(doc_gen: DocGen, root: Path, config: Path, strict: bool): try: - with open(root / ".doc_gen" / "validation.yaml", encoding="utf-8") as file: - validation = yaml.safe_load(file) - validation = validation or {} - doc_gen.validation.allow_list.update(validation.get("allow_list", [])) - doc_gen.validation.sample_files.update(validation.get("sample_files", [])) + content = doc_gen.fs.read(root / ".doc_gen" / "validation.yaml") + validation = yaml.safe_load(content) + validation = validation or {} + doc_gen.validation.allow_list.update(validation.get("allow_list", [])) + doc_gen.validation.sample_files.update(validation.get("sample_files", [])) except Exception: pass try: sdk_path = config / "sdks.yaml" - with sdk_path.open(encoding="utf-8") as file: - meta = yaml.safe_load(file) - sdks, errs = parse_sdks(sdk_path, meta, strict) - doc_gen.sdks = sdks - doc_gen.errors.extend(errs) + content = doc_gen.fs.read(sdk_path) + meta = yaml.safe_load(content) + sdks, errs = parse_sdks(sdk_path, meta, strict) + doc_gen.sdks = sdks + doc_gen.errors.extend(errs) except Exception: pass try: services_path = config / "services.yaml" - with services_path.open(encoding="utf-8") as file: - meta = yaml.safe_load(file) - services, service_errors = parse_services(services_path, meta) - doc_gen.services = services - for service in doc_gen.services.values(): - if service.expanded: - doc_gen.entities[service.long] = service.expanded.long - doc_gen.entities[service.short] = service.expanded.short - doc_gen.errors.extend(service_errors) + content = doc_gen.fs.read(services_path) + meta = yaml.safe_load(content) + services, service_errors = parse_services(services_path, meta) + doc_gen.services = services + for service in doc_gen.services.values(): + if service.expanded: + doc_gen.entities[service.long] = service.expanded.long + doc_gen.entities[service.short] = service.expanded.short + doc_gen.errors.extend(service_errors) except Exception: pass try: categories_path = config / "categories.yaml" - with categories_path.open(encoding="utf-8") as file: - meta = yaml.safe_load(file) - standard_categories, categories, errs = parse_categories( - categories_path, meta - ) - doc_gen.standard_categories = standard_categories - doc_gen.categories = categories - doc_gen.errors.extend(errs) + content = doc_gen.fs.read(categories_path) + meta = yaml.safe_load(content) + standard_categories, categories, errs = parse_categories(categories_path, meta) + doc_gen.standard_categories = standard_categories + doc_gen.categories = categories + doc_gen.errors.extend(errs) except Exception: pass try: entities_config_path = config / "entities.yaml" - with entities_config_path.open(encoding="utf-8") as file: - entities_config = yaml.safe_load(file) + content = doc_gen.fs.read(entities_config_path) + entities_config = yaml.safe_load(content) for entity, expanded in entities_config["expanded_override"].items(): doc_gen.entities[entity] = expanded except Exception: @@ -412,8 +422,9 @@ def parse_config(doc_gen: DocGen, root: Path, config: Path, strict: bool): metadata = root / ".doc_gen/metadata" try: + cross_content_path = metadata.parent / "cross-content" doc_gen.cross_blocks = set( - [path.name for path in (metadata.parent / "cross-content").glob("*.xml")] + [path.name for path in doc_gen.fs.glob(cross_content_path, "*.xml")] ) except Exception: pass diff --git a/aws_doc_sdk_examples_tools/doc_gen_test.py b/aws_doc_sdk_examples_tools/doc_gen_test.py index f1bb60a..5d6db57 100644 --- a/aws_doc_sdk_examples_tools/doc_gen_test.py +++ b/aws_doc_sdk_examples_tools/doc_gen_test.py @@ -16,6 +16,9 @@ from .sdks import Sdk, SdkVersion from .services import Service, ServiceExpanded from .snippets import Snippet +from .fs import PathFs + +SHARED_FS = PathFs() @pytest.mark.parametrize( @@ -24,6 +27,7 @@ ( DocGen( root=Path("/a"), + fs=SHARED_FS, errors=MetadataErrors(), sdks={ "a": Sdk( @@ -43,6 +47,7 @@ ), DocGen( root=Path("/b"), + fs=SHARED_FS, errors=MetadataErrors(), sdks={ "b": Sdk( @@ -62,6 +67,7 @@ ), DocGen( root=Path("/a"), + fs=SHARED_FS, errors=MetadataErrors(), sdks={ "a": Sdk( diff --git a/aws_doc_sdk_examples_tools/file_utils.py b/aws_doc_sdk_examples_tools/file_utils.py index 71039f4..bd78103 100644 --- a/aws_doc_sdk_examples_tools/file_utils.py +++ b/aws_doc_sdk_examples_tools/file_utils.py @@ -8,6 +8,7 @@ from shutil import rmtree from pathspec import GitIgnoreSpec +from aws_doc_sdk_examples_tools.fs import Fs, PathFs def match_path_to_specs(path: Path, specs: List[GitIgnoreSpec]) -> bool: @@ -21,7 +22,7 @@ def match_path_to_specs(path: Path, specs: List[GitIgnoreSpec]) -> bool: def walk_with_gitignore( - root: Path, specs: List[GitIgnoreSpec] = [] + root: Path, specs: List[GitIgnoreSpec] = [], fs: Fs = PathFs() ) -> Generator[Path, None, None]: """ Starting from a root directory, walk the file system yielding a path for each file. @@ -30,27 +31,31 @@ def walk_with_gitignore( fiddling with a number of flags. """ gitignore = root / ".gitignore" - if gitignore.exists(): - with open(root / ".gitignore", "r", encoding="utf-8") as ignore_file: - specs = [*specs, GitIgnoreSpec.from_lines(ignore_file.readlines())] - for entry in os.scandir(root): - path = Path(entry.path) + gitignore_stat = fs.stat(gitignore) + if gitignore_stat.exists: + lines = fs.readlines(gitignore) + specs = [*specs, GitIgnoreSpec.from_lines(lines)] + + for path in fs.list(root): if not match_path_to_specs(path, specs): - if entry.is_dir(): - yield from walk_with_gitignore(path, specs) + path_stat = fs.stat(path) + if path_stat.is_dir: + yield from walk_with_gitignore(path, specs, fs) else: - yield path + # Don't yield .gitignore files themselves + if path.name != ".gitignore": + yield path def get_files( - root: Path, skip: Callable[[Path], bool] = lambda _: False + root: Path, skip: Callable[[Path], bool] = lambda _: False, fs: Fs = PathFs() ) -> Generator[Path, None, None]: """ Yield non-skipped files, that is, anything not matching git ls-files and not in the "to skip" files that are in git but are machine generated, so we don't want to validate them. """ - for path in walk_with_gitignore(root): + for path in walk_with_gitignore(root, fs=fs): if not skip(path): yield path diff --git a/aws_doc_sdk_examples_tools/file_utils_test.py b/aws_doc_sdk_examples_tools/file_utils_test.py new file mode 100644 index 0000000..289781c --- /dev/null +++ b/aws_doc_sdk_examples_tools/file_utils_test.py @@ -0,0 +1,187 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for file_utils.py with filesystem abstraction. +""" + +from pathlib import Path + +from aws_doc_sdk_examples_tools.fs import RecordFs +from aws_doc_sdk_examples_tools.file_utils import walk_with_gitignore, get_files + + +class TestWalkWithGitignore: + """Test walk_with_gitignore with RecordFs.""" + + def test_basic_directory_traversal(self): + """Test basic directory traversal without gitignore.""" + fs = RecordFs( + { + Path("/root/file1.py"): "print('file1')", + Path("/root/file2.js"): "console.log('file2')", + } + ) + + files = list(walk_with_gitignore(Path("/root"), fs=fs)) + + expected = [ + Path("/root/file1.py"), + Path("/root/file2.js"), + ] + assert sorted(files) == sorted(expected) + + def test_gitignore_filtering(self): + """Test that gitignore rules are applied correctly.""" + fs = RecordFs( + { + Path("/root/.gitignore"): "*.tmp\n*.log\n", + Path("/root/keep.py"): "print('keep')", + Path("/root/ignore.tmp"): "temporary", + Path("/root/keep.js"): "console.log('keep')", + Path("/root/debug.log"): "log content", + } + ) + + files = list(walk_with_gitignore(Path("/root"), fs=fs)) + + # .gitignore files should not be included in results + expected = [ + Path("/root/keep.py"), + Path("/root/keep.js"), + ] + assert sorted(files) == sorted(expected) + + def test_no_gitignore_file(self): + """Test directory traversal when no .gitignore exists.""" + fs = RecordFs( + { + Path("/root/file1.py"): "print('file1')", + Path("/root/file2.js"): "console.log('file2')", + Path("/root/file3.txt"): "text content", + } + ) + + files = list(walk_with_gitignore(Path("/root"), fs=fs)) + + expected = [ + Path("/root/file1.py"), + Path("/root/file2.js"), + Path("/root/file3.txt"), + ] + assert sorted(files) == sorted(expected) + + def test_empty_directory(self): + """Test walking an empty directory.""" + fs = RecordFs({}) + + files = list(walk_with_gitignore(Path("/empty"), fs=fs)) + + assert files == [] + + def test_directory_with_only_gitignore(self): + """Test directory that only contains .gitignore file.""" + fs = RecordFs( + { + Path("/root/.gitignore"): "*.tmp\n", + } + ) + + files = list(walk_with_gitignore(Path("/root"), fs=fs)) + + assert files == [] + + def test_nested_gitignores(self): + """Test nested gitignore files with different rules.""" + fs = RecordFs( + { + # Root level gitignore ignores *.log files + Path("/root/.gitignore"): "*.log\n", + Path("/root/keep.py"): "print('keep')", + Path("/root/debug.log"): "root log", # Should be ignored + # Nested directory with its own gitignore ignoring *.tmp files + Path("/root/subdir/.gitignore"): "*.tmp\n", + Path("/root/subdir/keep.js"): "console.log('keep')", + Path( + "/root/subdir/ignore.tmp" + ): "temporary", # Should be ignored by subdir gitignore + Path( + "/root/subdir/keep.log" + ): "nested log", # Should be ignored by root gitignore + } + ) + + files = list(walk_with_gitignore(Path("/root"), fs=fs)) + + # Only files that don't match any gitignore pattern should be returned + expected = [ + Path("/root/keep.py"), + Path("/root/subdir/keep.js"), + ] + assert sorted(files) == sorted(expected) + + +class TestGetFiles: + """Test get_files with RecordFs.""" + + def test_get_files_basic(self): + """Test basic get_files functionality.""" + fs = RecordFs( + { + Path("/root/file1.py"): "print('file1')", + Path("/root/file2.js"): "console.log('file2')", + } + ) + + files = list(get_files(Path("/root"), fs=fs)) + + expected = [ + Path("/root/file1.py"), + Path("/root/file2.js"), + ] + assert sorted(files) == sorted(expected) + + def test_get_files_with_skip_function(self): + """Test get_files with skip function.""" + fs = RecordFs( + { + Path("/root/keep.py"): "print('keep')", + Path("/root/skip.py"): "print('skip')", + Path("/root/keep.js"): "console.log('keep')", + Path("/root/skip.js"): "console.log('skip')", + } + ) + + def skip_function(path: Path) -> bool: + return "skip" in path.name + + files = list(get_files(Path("/root"), skip=skip_function, fs=fs)) + + expected = [ + Path("/root/keep.py"), + Path("/root/keep.js"), + ] + assert sorted(files) == sorted(expected) + + def test_get_files_with_gitignore_and_skip(self): + """Test get_files with both gitignore and skip function.""" + fs = RecordFs( + { + Path("/root/.gitignore"): "*.tmp\n", + Path("/root/keep.py"): "print('keep')", + Path("/root/skip.py"): "print('skip')", + Path("/root/ignore.tmp"): "temporary", + Path("/root/keep.js"): "console.log('keep')", + } + ) + + def skip_function(path: Path) -> bool: + return "skip" in path.name + + files = list(get_files(Path("/root"), skip=skip_function, fs=fs)) + + expected = [ + Path("/root/keep.py"), + Path("/root/keep.js"), + ] + assert sorted(files) == sorted(expected) diff --git a/aws_doc_sdk_examples_tools/fs.py b/aws_doc_sdk_examples_tools/fs.py index e980e60..803563c 100644 --- a/aws_doc_sdk_examples_tools/fs.py +++ b/aws_doc_sdk_examples_tools/fs.py @@ -27,6 +27,10 @@ def glob(self, path: Path, glob: str) -> Generator[Path, None, None]: def read(self, path: Path) -> str: pass + @abstractmethod + def readlines(self, path: Path) -> List[str]: + pass + @abstractmethod def write(self, path: Path, content: str): pass @@ -49,9 +53,13 @@ def glob(self, path: Path, glob: str) -> Generator[Path, None, None]: return path.glob(glob) def read(self, path: Path) -> str: - with path.open("r") as file: + with path.open("r", encoding="utf-8") as file: return file.read() + def readlines(self, path: Path) -> List[str]: + with path.open("r") as file: + return file.readlines() + def write(self, path: Path, content: str): with path.open("w") as file: file.write(content) @@ -87,6 +95,10 @@ def glob(self, path: Path, glob: str) -> Generator[Path, None, None]: def read(self, path: Path) -> str: return self.fs[path] + def readlines(self, path: Path) -> List[str]: + content = self.fs[path] + return content.splitlines(keepends=True) + def write(self, path: Path, content: str): base = str(path.parent) assert any( @@ -106,7 +118,24 @@ def mkdir(self, path: Path): self.fs.setdefault(path, "") def list(self, path: Path) -> List[Path]: - return [item for item in self.fs.keys() if item.parent == path] + # If it's a file, return an empty list + if self.stat(path).is_file: + return [] + + # Gather all entries that are immediate children of `path` + prefix = str(path).rstrip("/") + "/" + children = set() + + for item in self.fs.keys(): + item_s = str(item) + if item_s.startswith(prefix): + # Determine the remainder path after the prefix + remainder = item_s[len(prefix) :] + # Split off the first component + first_part = remainder.split("/", 1)[0] + children.add(Path(prefix + first_part)) + + return sorted(children) fs = PathFs() diff --git a/aws_doc_sdk_examples_tools/fs_test.py b/aws_doc_sdk_examples_tools/fs_test.py new file mode 100644 index 0000000..8f628fa --- /dev/null +++ b/aws_doc_sdk_examples_tools/fs_test.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for the Fs interface, particularly the readlines functionality. +""" + +import pytest +import tempfile +from pathlib import Path +from typing import List + +from .fs import Fs, PathFs, RecordFs + + +def assert_readlines_result(fs: Fs, path: Path, expected: List[str]): + """Generic assertion for readlines results.""" + lines = fs.readlines(path) + assert lines == expected + assert len(lines) == len(expected) + + +def run_common_readlines_scenarios(fs: Fs, path_factory): + """Test common readlines scenarios for any Fs implementation.""" + # Basic multi-line content + path = path_factory("Line 1\nLine 2\nLine 3\n") + assert_readlines_result(fs, path, ["Line 1\n", "Line 2\n", "Line 3\n"]) + + # Empty file + path = path_factory("") + assert_readlines_result(fs, path, []) + + # No final newline + path = path_factory("Line 1\nLine 2") + assert_readlines_result(fs, path, ["Line 1\n", "Line 2"]) + + # Single line + path = path_factory("Single line\n") + assert_readlines_result(fs, path, ["Single line\n"]) + + +class TestPathFs: + """Test PathFs implementation of readlines.""" + + def test_readlines_scenarios(self): + """Test various readlines scenarios with PathFs.""" + fs = PathFs() + temp_files = [] + + def path_factory(content: str) -> Path: + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) as f: + f.write(content) + path = Path(f.name) + temp_files.append(path) + return path + + try: + run_common_readlines_scenarios(fs, path_factory) + finally: + # Clean up temp files + errors = [] + for path in temp_files: + try: + if path.exists(): + path.unlink() + except Exception as e: + errors.append((path, e)) + if errors: + messages = "\n".join( + f"{path}: {type(e).__name__}: {e}" for path, e in errors + ) + pytest.fail( + f"Errors occurred while cleaning up temp files:\n{messages}" + ) + + +class TestRecordFs: + """Test RecordFs implementation of readlines.""" + + def test_readlines_scenarios(self): + """Test various readlines scenarios with RecordFs.""" + test_cases = [ + ("Line 1\nLine 2\nLine 3\n", ["Line 1\n", "Line 2\n", "Line 3\n"]), + ("", []), + ("Line 1\nLine 2", ["Line 1\n", "Line 2"]), + ("Single line\n", ["Single line\n"]), + ] + + for content, expected in test_cases: + fs = RecordFs({Path("test.txt"): content}) + assert_readlines_result(fs, Path("test.txt"), expected) + + def test_readlines_line_ending_variations(self): + """Test readlines with different line ending styles.""" + test_cases = [ + ( + "Line 1\r\nLine 2\r\nLine 3\r\n", + ["Line 1\r\n", "Line 2\r\n", "Line 3\r\n"], + ), + ("Line 1\nLine 2\r\nLine 3\n", ["Line 1\n", "Line 2\r\n", "Line 3\n"]), + ] + + for content, expected in test_cases: + fs = RecordFs({Path("test.txt"): content}) + assert_readlines_result(fs, Path("test.txt"), expected)