Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions aws_doc_sdk_examples_tools/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def read(self, path: Path) -> str:
pass

@abstractmethod
def readlines(self, path: Path) -> List[str]:
def readlines(self, path: Path, encoding: str = "utf-8") -> List[str]:
pass

@abstractmethod
Expand Down Expand Up @@ -56,12 +56,12 @@ def read(self, path: Path) -> str:
with path.open("r", encoding="utf-8") as file:
return file.read()

def readlines(self, path: Path) -> List[str]:
with path.open("r") as file:
def readlines(self, path: Path, encoding: str = "utf-8") -> List[str]:
with path.open("r", encoding=encoding) as file:
return file.readlines()

def write(self, path: Path, content: str):
with path.open("w") as file:
with path.open("w", encoding="utf-8") as file:
file.write(content)

def stat(self, path: Path) -> Stat:
Expand Down Expand Up @@ -95,7 +95,7 @@ def glob(self, path: Path, glob: str) -> Generator[Path, None, None]:
def read(self, path: Path) -> str:
return self.fs[path]

def readlines(self, path: Path) -> List[str]:
def readlines(self, path: Path, encoding: str = "utf-8") -> List[str]:
content = self.fs[path]
return content.splitlines(keepends=True)

Expand Down
46 changes: 24 additions & 22 deletions aws_doc_sdk_examples_tools/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .validator_config import skip
from .file_utils import get_files, clear
from .fs import Fs, PathFs
from .metadata import Example
from .metadata_errors import MetadataErrors, MetadataError
from .project_validator import (
Expand Down Expand Up @@ -145,16 +146,17 @@ def parse_snippets(
return snippets, errors


def find_snippets(file: Path, prefix: str) -> Tuple[Dict[str, Snippet], MetadataErrors]:
def find_snippets(
file: Path, prefix: str, fs: Fs = PathFs()
) -> Tuple[Dict[str, Snippet], MetadataErrors]:
errors = MetadataErrors()
snippets: Dict[str, Snippet] = {}
try:
with open(file, encoding="utf-8") as snippet_file:
try:
snippets, errs = parse_snippets(snippet_file.readlines(), file, prefix)
errors.extend(errs)
except UnicodeDecodeError as err:
errors.append(MetadataUnicodeError(file=file, err=err))
lines = fs.readlines(file)
snippets, errs = parse_snippets(lines, file, prefix)
errors.extend(errs)
except UnicodeDecodeError as err:
errors.append(MetadataUnicodeError(file=file, err=err))
except FileNotFoundError:
pass
except Exception as e:
Expand All @@ -163,12 +165,12 @@ def find_snippets(file: Path, prefix: str) -> Tuple[Dict[str, Snippet], Metadata


def collect_snippets(
root: Path, prefix: str = ""
root: Path, prefix: str = "", fs: Fs = PathFs()
) -> Tuple[Dict[str, Snippet], MetadataErrors]:
snippets: Dict[str, Snippet] = {}
errors = MetadataErrors()
for file in get_files(root, skip):
snips, errs = find_snippets(file, prefix)
for file in get_files(root, skip, fs=fs):
snips, errs = find_snippets(file, prefix, fs=fs)
snippets.update(snips)
errors.extend(errs)
return snippets, errors
Expand All @@ -180,14 +182,17 @@ def collect_snippet_files(
prefix: str,
errors: MetadataErrors,
root: Path,
fs: Fs = PathFs(),
):
for example in examples:
for lang in example.languages:
language = example.languages[lang]
for version in language.versions:
for excerpt in version.excerpts:
for snippet_file in excerpt.snippet_files:
if not (root / snippet_file).exists():
snippet_path = root / snippet_file
snippet_stat = fs.stat(snippet_path)
if not snippet_stat.exists:
# Ensure all snippet_files exist
errors.append(
MissingSnippetFile(
Expand All @@ -207,17 +212,14 @@ def collect_snippet_files(
)
continue
name = prefix + str(snippet_file).replace("/", ".")
with open(root / snippet_file, encoding="utf-8") as file:
code = file.readlines()
snippets[name] = Snippet(
id=name,
file=snippet_file,
line_start=0,
line_end=len(code),
code="".join(
strip_snippet_tags(strip_spdx_header(code))
),
)
code = fs.readlines(snippet_path)
snippets[name] = Snippet(
id=name,
file=snippet_file,
line_start=0,
line_end=len(code),
code="".join(strip_snippet_tags(strip_spdx_header(code))),
)


def strip_snippet_tags(lines: List[str]) -> List[str]:
Expand Down
159 changes: 159 additions & 0 deletions aws_doc_sdk_examples_tools/snippets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from pathlib import Path

from aws_doc_sdk_examples_tools import snippets
from aws_doc_sdk_examples_tools.fs import RecordFs
from aws_doc_sdk_examples_tools.metadata import Example, Language, Version, Excerpt
from aws_doc_sdk_examples_tools.metadata_errors import MetadataErrors


@pytest.mark.parametrize(
Expand Down Expand Up @@ -102,3 +105,159 @@ def test_strip_spdx_header():
)

assert [] == snippets.strip_spdx_header([])


class TestFindSnippetsFs:
    """Exercise find_snippets against an in-memory RecordFs."""

    def test_find_snippets_with_recordfs(self):
        """A tagged region in a RecordFs-backed file comes back as one snippet."""
        source_path = Path("/project/test.py")
        source_text = (
            "# snippet-start:[example.hello]\n"
            "def hello():\n"
            'print("Hello, World!")\n'
            "# snippet-end:[example.hello]\n"
        )
        fs = RecordFs({source_path: source_text})

        found, errors = snippets.find_snippets(source_path, "", fs=fs)

        assert len(errors) == 0
        assert len(found) == 1
        assert "example.hello" in found
        hello = found["example.hello"]
        assert hello.id == "example.hello"
        assert "def hello():" in hello.code

    def test_find_snippets_missing_file_graceful(self):
        """A path absent from the RecordFs yields no snippets and one error."""
        empty_fs = RecordFs({})

        found, errors = snippets.find_snippets(
            Path("/project/missing.py"), "", fs=empty_fs
        )

        # Despite the test name, the failed lookup is not silently ignored:
        # find_snippets only swallows FileNotFoundError, and this lookup ends
        # up in its generic error path instead.
        # NOTE(review): presumably RecordFs raises KeyError for an unknown
        # path and the recorded error is a FileReadError -- confirm.
        assert len(found) == 0
        assert len(errors) == 1


class TestCollectSnippetsFs:
    """Exercise collect_snippets against an in-memory RecordFs."""

    def test_collect_snippets_with_recordfs(self):
        """Snippets found in each file under the root are merged into one dict."""
        first_source = (
            "# snippet-start:[example1]\n"
            "def example1():\n"
            "pass\n"
            "# snippet-end:[example1]\n"
        )
        second_source = (
            "# snippet-start:[example2]\n"
            "def example2():\n"
            "pass\n"
            "# snippet-end:[example2]\n"
        )
        files = {
            Path("/project/src/file1.py"): first_source,
            Path("/project/src/file2.py"): second_source,
        }
        fs = RecordFs(files)

        found, errors = snippets.collect_snippets(Path("/project/src"), fs=fs)

        assert len(errors) == 0
        assert len(found) == 2
        # One snippet per source file, keyed by its tag.
        for tag in ("example1", "example2"):
            assert tag in found


class TestCollectSnippetFilesFs:
    """Exercise collect_snippet_files against an in-memory RecordFs."""

    @staticmethod
    def _example_with_snippet_file(snippet_file):
        """Build a minimal Example whose single excerpt lists ``snippet_file``.

        Both tests need the same Example -> Language -> Version -> Excerpt
        chain and differ only in the referenced file name, so the fixture is
        built here. The leading underscore keeps pytest from collecting this
        helper as a test.
        """
        excerpt = Excerpt(
            description="Test excerpt",
            snippet_tags=[],
            snippet_files=[snippet_file],
        )
        version = Version(sdk_version="3", excerpts=[excerpt])
        language = Language(name="python", property="python", versions=[version])
        return Example(
            id="test_example",
            file=None,
            languages={"python": language},
        )

    def test_collect_snippet_files_with_recordfs(self):
        """A snippet file present in the RecordFs is loaded as a whole-file snippet."""
        fs = RecordFs({Path("/project/example.py"): "print('Hello, World!')\n"})
        example = self._example_with_snippet_file("example.py")

        # collect_snippet_files fills these two containers in place.
        snippet_dict = {}
        errors = MetadataErrors()

        snippets.collect_snippet_files(
            [example], snippet_dict, "", errors, Path("/project"), fs=fs
        )

        assert len(errors) == 0
        assert len(snippet_dict) == 1
        # The key is the prefix ("" here) plus the snippet file name.
        assert "example.py" in snippet_dict
        snippet = snippet_dict["example.py"]
        assert snippet.file == "example.py"
        assert "Hello, World!" in snippet.code

    def test_collect_snippet_files_missing_file_error(self):
        """A snippet file absent from the RecordFs is reported through ``errors``."""
        fs = RecordFs({})  # Empty filesystem
        example = self._example_with_snippet_file("missing.py")

        snippet_dict = {}
        errors = MetadataErrors()

        snippets.collect_snippet_files(
            [example], snippet_dict, "", errors, Path("/project"), fs=fs
        )

        # The file cannot be found, so exactly one error is recorded and no
        # snippet is created for it.
        assert len(errors) == 1
        assert len(snippet_dict) == 0