Skip to content

Commit ed6d44d

Browse files
authored
Validate template variables and improve defaults (frequenz-floss#95)
- Add some basic variable validation - Improve defaults built with `name` (fixes frequenz-floss#79) - Make sure name is converted to an identifier for files
2 parents b7ae9ee + be9b685 commit ed6d44d

File tree

20 files changed

+133
-26
lines changed

20 files changed

+133
-26
lines changed

cookiecutter/cookiecutter.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"default_codeowners": "(like @some-org/some-team; defaults to a team based on the repo type)",
2525
"_extensions": [
2626
"jinja2_time.TimeExtension",
27+
"local_extensions.as_identifier",
2728
"local_extensions.default_codeowners",
2829
"local_extensions.github_repo_name",
2930
"local_extensions.introduction",

cookiecutter/hooks/post_gen_project.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,10 @@ def finish_lib_setup() -> None:
474474
- `lib`: `src/frequenz/{name}`
475475
- `rest`: `src/frequenz/{type}/{name}`
476476
"""
477+
name = cookiecutter.name.lower().replace("-", "_")
477478
recursive_overwrite_move(
478-
_pathlib.Path(f"src/frequenz/{cookiecutter.type}/{cookiecutter.name}"),
479-
_pathlib.Path(f"src/frequenz/{cookiecutter.name}"),
479+
_pathlib.Path(f"src/frequenz/{cookiecutter.type}/{name}"),
480+
_pathlib.Path(f"src/frequenz/{name}"),
480481
)
481482
_pathlib.Path(f"src/frequenz/{cookiecutter.type}").rmdir()
482483

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# License: MIT
2+
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Cookiecutter pre-generation hooks.
5+
6+
This module contains the pre-generation hooks for the cookiecutter template. It
7+
validates the cookiecutter variables and prints an error message and exits with a
8+
non-zero exit code if any of them are invalid.
9+
"""
10+
11+
import collections
12+
import json
13+
import re
14+
import sys
15+
from typing import Any
16+
17+
NAME_REGEX = re.compile(r"^[a-zA-Z][_a-zA-Z0-9]+(-[_a-zA-Z][_a-zA-Z0-9]+)*$")
18+
PYTHON_PACKAGE_REGEX = re.compile(r"^[a-zA-Z][_a-zA-Z0-9]+(\.[_a-zA-Z][_a-zA-Z0-9]+)*$")
19+
PYPI_PACKAGE_REGEX = NAME_REGEX
20+
21+
22+
def to_named_tuple(dictionary: dict[Any, Any], /) -> Any:
23+
"""Convert a dictionary to a named tuple.
24+
25+
Args:
26+
dictionary: The dictionary to convert.
27+
28+
Returns:
29+
The named tuple with the same keys and values as the dictionary.
30+
"""
31+
filtered = {k: v for k, v in dictionary.items() if not k.startswith("_")}
32+
return collections.namedtuple("Cookiecutter", filtered.keys())(*filtered.values())
33+
34+
35+
cookiecutter = to_named_tuple(json.loads(r"""{{cookiecutter | tojson}}"""))
36+
37+
38+
def main() -> None:
39+
"""Validate the cookiecutter variables.
40+
41+
This function validates the cookiecutter variables and prints an error message and
42+
exits with a non-zero exit code if any of them are invalid.
43+
"""
44+
errors: dict[str, list[str]] = {}
45+
46+
def add_error(key: str, message: str) -> None:
47+
"""Add an error to the error dictionary.
48+
49+
Args:
50+
key: The key of the error.
51+
message: The error message.
52+
"""
53+
errors.setdefault(key, []).append(message)
54+
55+
if not NAME_REGEX.match(cookiecutter.name):
56+
add_error("name", f"Invalid project name (must match {NAME_REGEX.pattern})")
57+
58+
if not PYTHON_PACKAGE_REGEX.match(cookiecutter.python_package):
59+
add_error(
60+
"python_package",
61+
f"Invalid package name (must match {PYTHON_PACKAGE_REGEX.pattern})",
62+
)
63+
64+
if not PYPI_PACKAGE_REGEX.match(cookiecutter.pypi_package_name):
65+
add_error(
66+
"pypi_package_name",
67+
f"Invalid package name (must match {PYPI_PACKAGE_REGEX.pattern})",
68+
)
69+
70+
if errors:
71+
print("The following errors were found:", file=sys.stderr)
72+
for key, messages in errors.items():
73+
print(f" {key}:", file=sys.stderr)
74+
for message in messages:
75+
print(f" - {message}", file=sys.stderr)
76+
sys.exit(1)
77+
78+
79+
if __name__ == "__main__":
80+
main()

cookiecutter/local_extensions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ def _build_identifier(repo_type: str, name: str, separator: str) -> str:
2424
Returns:
2525
The built identifier.
2626
"""
27+
if separator == ".":
28+
name = name.replace("-", "_")
29+
if separator == "-":
30+
name = name.replace("_", "-")
31+
name = name.lower()
2732
middle = f"{repo_type}{separator}" if repo_type != "lib" else ""
2833
return f"frequenz{separator}{middle}{name}"
2934

@@ -42,6 +47,19 @@ def _get_from_json(key: str) -> str:
4247

4348

4449
# Ignoring because cookiecutter simple_filter decorator is not typed.
50+
@_simple_filter # type: ignore[misc]
51+
def as_identifier(name: str) -> str:
52+
"""Convert a name to a valid identifier.
53+
54+
Args:
55+
name: The name to convert.
56+
57+
Returns:
58+
The converted identifier.
59+
"""
60+
return name.lower().replace("-", "_")
61+
62+
4563
@_simple_filter # type: ignore[misc]
4664
def python_package(cookiecutter: dict[str, str]) -> str:
4765
"""Generate the Python package (import) depending on the repository type.
@@ -93,7 +111,7 @@ def title(cookiecutter: dict[str, str]) -> str:
93111
Returns:
94112
The default site name.
95113
"""
96-
name = cookiecutter["name"].capitalize()
114+
name = cookiecutter["name"].replace("_", " ").replace("-", " ").title()
97115
match cookiecutter["type"]:
98116
case "actor":
99117
return f"Frequenz {name} Actor"

cookiecutter/variable-reference.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
* `name`: The name of the project. This will be used to build defaults for
44
other inputs, such as `title`, `python_package`, etc. It should be one word,
5-
using only alphanumeric characters (and starting with a letter).
5+
using only alphanumeric characters (and starting with a letter). It can
6+
include also `_` and `-` which will be handled differently when building
7+
other variables from it (replaced by spaces in titles for example).
68

79
* `description`: A short description of the project. It will be used as the
810
description in the `README.md`, `pyproject.toml`, `mkdocs.yml`, etc.

cookiecutter/{{cookiecutter.github_repo_name}}/.github/labeler.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
# For example:
1111
#
1212
# "part:module":
13-
# - "src/frequenz/{{cookiecutter.type}}/{{cookiecutter.name}}/module/**"
13+
# - "src/frequenz/{{cookiecutter.type}}/{{cookiecutter.name | as_identifier}}/module/**"
1414
#
1515
# "part:other":
16-
# - "src/frequenz/{{cookiecutter.type}}/{{cookiecutter.name}}/other/**"
16+
# - "src/frequenz/{{cookiecutter.type}}/{{cookiecutter.name | as_identifier}}/other/**"
1717
#
1818
# # For excluding some files (in this example, label "part:complicated"
1919
# # everything inside src/ with a .py suffix, except for src/__init__.py)

cookiecutter/{{cookiecutter.github_repo_name}}/proto/frequenz/api/{{cookiecutter.name}}/{{cookiecutter.name}}.proto renamed to cookiecutter/{{cookiecutter.github_repo_name}}/proto/frequenz/api/{{cookiecutter.name | as_identifier}}/{{cookiecutter.name | as_identifier}}.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
syntax = "proto3";
1414

15-
package frequenz.api.{{cookiecutter.name}}.{{cookiecutter.name}};
15+
package frequenz.api.{{cookiecutter.name | as_identifier}}.{{cookiecutter.name | as_identifier}};
1616

1717
// An example message.
1818
//

cookiecutter/{{cookiecutter.github_repo_name}}/tests/test_{{cookiecutter.name}}.py renamed to cookiecutter/{{cookiecutter.github_repo_name}}/tests/test_{{cookiecutter.name | as_identifier}}.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,42 @@
11
# License: {{cookiecutter.license}}
22
# Copyright © {% now 'utc', '%Y' %} {{cookiecutter.author_name}}
33

4-
"""Tests for the {{cookiecutter.name}} package."""
4+
"""Tests for the {{cookiecutter.python_package}} package."""
55

66
{%- if cookiecutter.type == "api" %}
77

88

99
def test_package_import() -> None:
1010
"""Test that the package can be imported."""
1111
# pylint: disable=import-outside-toplevel
12-
from frequenz.api import {{cookiecutter.name}}
12+
from frequenz.api import {{cookiecutter.name | as_identifier}}
1313

14-
assert {{cookiecutter.name}} is not None
14+
assert {{cookiecutter.name | as_identifier}} is not None
1515

1616

1717
def test_module_import_components() -> None:
1818
"""Test that the modules can be imported."""
1919
# pylint: disable=import-outside-toplevel
20-
from frequenz.api.{{cookiecutter.name}} import {{cookiecutter.name}}_pb2
20+
from frequenz.api.{{cookiecutter.name | as_identifier}} import {{cookiecutter.name | as_identifier}}_pb2
2121

22-
assert {{cookiecutter.name}}_pb2 is not None
22+
assert {{cookiecutter.name | as_identifier}}_pb2 is not None
2323

2424
# pylint: disable=import-outside-toplevel
25-
from frequenz.api.{{cookiecutter.name}} import {{cookiecutter.name}}_pb2_grpc
25+
from frequenz.api.{{cookiecutter.name | as_identifier}} import {{cookiecutter.name | as_identifier}}_pb2_grpc
2626

27-
assert {{cookiecutter.name}}_pb2_grpc is not None
27+
assert {{cookiecutter.name | as_identifier}}_pb2_grpc is not None
2828
{%- else %}
2929
import pytest
3030

3131
from {{cookiecutter.python_package}} import delete_me
3232

3333

34-
def test_{{cookiecutter.name}}_succeeds() -> None: # TODO(cookiecutter): Remove
34+
def test_{{cookiecutter.name | as_identifier}}_succeeds() -> None: # TODO(cookiecutter): Remove
3535
"""Test that the delete_me function succeeds."""
3636
assert delete_me() is True
3737

3838

39-
def test_{{cookiecutter.name}}_fails() -> None: # TODO(cookiecutter): Remove
39+
def test_{{cookiecutter.name | as_identifier}}_fails() -> None: # TODO(cookiecutter): Remove
4040
"""Test that the delete_me function fails."""
4141
with pytest.raises(RuntimeError, match="This function should be removed!"):
4242
delete_me(blow_up=True)

0 commit comments

Comments
 (0)