Skip to content

Commit ed6dc16

Browse files
authored
DOP-5507: allow merging of rstspec and snooty configs for composables (#660)
* init commit. merge spec and snooty config for composables * add diagnostics * make spec parser intake rst spec as well * log error from snooty toml * allow writers to dictate dependencies in total * return tuple * fix keyerror * address comments * merge fix
1 parent d2eb3d7 commit ed6dc16

File tree

5 files changed

+166
-25
lines changed

5 files changed

+166
-25
lines changed

snooty/main.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
from .page import Page
4141
from .parser import Project, ProjectBackend, ProjectLoadError
4242
from .types import BuildIdentifierSet, ProjectConfig
43-
from .util import EXT_FOR_PAGE, SOURCE_FILE_EXTENSIONS, HTTPCache, PerformanceLogger
43+
from .util import (
44+
EXT_FOR_PAGE,
45+
SNOOTY_TOML,
46+
SOURCE_FILE_EXTENSIONS,
47+
HTTPCache,
48+
PerformanceLogger,
49+
)
4450

4551
PARANOID_MODE = os.environ.get("SNOOTY_PARANOID", "0") == "1"
4652
PATTERNS = ["*" + ext for ext in SOURCE_FILE_EXTENSIONS]
@@ -256,15 +262,6 @@ def main() -> None:
256262

257263
logger.info(f"Snooty {__version__} starting")
258264

259-
if args["--rstspec"]:
260-
rstspec_path = args["--rstspec"]
261-
if rstspec_path.startswith("https://") or rstspec_path.startswith("http://"):
262-
rstspec_bytes = HTTPCache.singleton().get(args["--rstspec"])
263-
rstspec_text = str(rstspec_bytes, "utf-8")
264-
else:
265-
rstspec_text = Path(rstspec_path).expanduser().read_text(encoding="utf-8")
266-
specparser.Spec.initialize(rstspec_text)
267-
268265
if PARANOID_MODE:
269266
logger.info("Paranoid mode on")
270267

@@ -283,6 +280,17 @@ def main() -> None:
283280
assert args["<source-path>"] is not None
284281
root_path = Path(args["<source-path>"])
285282

283+
if args["--rstspec"]:
284+
rstspec_path = args["--rstspec"]
285+
if rstspec_path.startswith("https://") or rstspec_path.startswith("http://"):
286+
rstspec_bytes = HTTPCache.singleton().get(args["--rstspec"])
287+
rstspec_text = str(rstspec_bytes, "utf-8")
288+
else:
289+
rstspec_text = Path(rstspec_path).expanduser().read_text(encoding="utf-8")
290+
specparser.Spec.initialize(
291+
rstspec_text, Path.joinpath(root_path, Path(SNOOTY_TOML))
292+
)
293+
286294
branch = args["--branch"]
287295

288296
try:

snooty/parser.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,9 @@ def handle_tabset(self, node: n.Directive) -> None:
615615
line = node.start[0]
616616
# retrieve dictionary associated with this specific tabset
617617
try:
618-
tab_definitions_list = specparser.Spec.get().tabs[tabset]
618+
tab_definitions_list = specparser.Spec.get(
619+
self.project_config.config_path
620+
).tabs[tabset]
619621
except KeyError:
620622
self.diagnostics.append(UnknownTabset(tabset, line))
621623
return
@@ -675,7 +677,9 @@ def handle_tabset(self, node: n.Directive) -> None:
675677
)
676678

677679
def handle_wayfinding(self, node: n.Directive) -> None:
678-
expected_options = specparser.Spec.get().wayfinding["options"]
680+
expected_options = specparser.Spec.get(
681+
self.project_config.config_path
682+
).wayfinding["options"]
679683
expected_options_dict = {option.id: option for option in expected_options}
680684
expected_child_opt_name = "wayfinding-option"
681685
expected_child_desc_name = "wayfinding-description"
@@ -797,7 +801,9 @@ def check_valid_option_id(
797801
raise ChildValidationError()
798802

799803
def handle_method_selector(self, node: n.Directive) -> None:
800-
expected_options = specparser.Spec.get().method_selector["options"]
804+
expected_options = specparser.Spec.get(
805+
self.project_config.config_path
806+
).method_selector["options"]
801807
expected_options_dict = {option.id: option for option in expected_options}
802808
expected_child_name = "method-option"
803809

@@ -885,7 +891,9 @@ def handle_composable(self, node: n.ComposableDirective) -> None:
885891
)
886892

887893
# get the expected composable options from the spec
888-
spec_composables = specparser.Spec.get().composables
894+
spec_composables = specparser.Spec.get(
895+
self.project_config.config_path
896+
).composables
889897
spec_composables_dict = {
890898
expected_option.id: expected_option for expected_option in spec_composables
891899
}

snooty/postprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,7 @@ def __init__(self, context: Context) -> None:
16971697
super().__init__(context)
16981698
self.project_config = context[ProjectConfig]
16991699
self.targets = context[TargetDatabase]
1700-
self.spec = specparser.Spec.get()
1700+
self.spec = specparser.Spec.get(self.project_config.config_path)
17011701

17021702
def enter_node(self, fileid_stack: FileIdStack, node: n.Node) -> None:
17031703
"""When a node of type ref_role is encountered, ensure that it references a valid target.

snooty/specparser.py

Lines changed: 134 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Parser for a TOML spec file containing definitions of all supported reStructuredText
2-
directives and roles, and what types of data each should expect."""
2+
directives and roles, and what types of data each should expect."""
33

44
from __future__ import annotations
55

66
import dataclasses
7+
import logging
78
from dataclasses import dataclass, field
89
from datetime import datetime
910
from enum import Enum
11+
from pathlib import Path
1012
from typing import (
1113
Any,
1214
Callable,
@@ -18,13 +20,16 @@
1820
Optional,
1921
Sequence,
2022
Set,
23+
Tuple,
2124
TypeVar,
2225
Union,
2326
)
2427

2528
import tomli
2629
from typing_extensions import Protocol
2730

31+
from snooty.diagnostics import Diagnostic, UnknownOptionId
32+
2833
from . import tinydocutils, util
2934
from .flutter import check_type, checked
3035

@@ -70,6 +75,8 @@
7075
PrimitiveType.linenos: util.option_string,
7176
}
7277

78+
logger = logging.getLogger(__name__)
79+
7380
#: Option types can be a primitive type (PrimitiveType), an enum
7481
#: defined in the spec, or a union of those.
7582
ArgumentType = Union[List[Union[PrimitiveType, str]], PrimitiveType, str, None]
@@ -308,6 +315,7 @@ class Spec:
308315
wayfinding: Dict[str, List[WayfindingOption]] = field(default_factory=dict)
309316
data_fields: List[str] = field(default_factory=list)
310317
composables: List[Composable] = field(default_factory=list)
318+
merged: bool = False
311319

312320
SPEC: ClassVar[Optional[Spec]] = None
313321

@@ -439,15 +447,131 @@ def resolve_value(key: str, inheritable: _T) -> _T:
439447
resolve_value(key, inheritable)
440448

441449
@classmethod
442-
def initialize(cls, text: str) -> None:
450+
def _merge_composables(
451+
cls, spec: Spec, custom_composables: List[Dict[str, Any]]
452+
) -> Tuple[Spec, List[Diagnostic]]:
453+
res: List[Composable] = []
454+
diagnostics: List[Diagnostic] = []
455+
456+
custom_composable_by_id = {
457+
composable["id"]: composable for composable in custom_composables
458+
}
459+
460+
for defined_composable in spec.composables:
461+
custom_composable = custom_composable_by_id.pop(defined_composable.id, None)
462+
if not custom_composable:
463+
res.append(defined_composable)
464+
continue
465+
merged_title = custom_composable["title"]
466+
467+
# merge all the options
468+
defined_options = {
469+
option.id: option for option in defined_composable.options
470+
}
471+
custom_options = {
472+
option["id"]: option for option in custom_composable["options"]
473+
}
474+
475+
merged_options = []
476+
for option_id in set(defined_options.keys()) | set(custom_options.keys()):
477+
if option_id in custom_options:
478+
custom_option = custom_options[option_id]
479+
merged_options.append(
480+
TabDefinition(custom_option["id"], custom_option["title"])
481+
)
482+
else:
483+
merged_options.append(defined_options[option_id])
484+
485+
merged_dependencies = (
486+
custom_composable["dependencies"]
487+
if "dependencies" in custom_composable
488+
else defined_composable.dependencies
489+
)
490+
491+
merged_default = (
492+
custom_composable["default"]
493+
if "default" in custom_composable
494+
else defined_composable.default
495+
)
496+
default_option = next(
497+
(
498+
option
499+
for option in merged_options
500+
if merged_default and option.id == merged_default
501+
),
502+
None,
503+
)
504+
if merged_default and not default_option:
505+
diagnostics.append(
506+
UnknownOptionId(
507+
"Spec composables default",
508+
merged_default,
509+
[option.title for option in merged_options],
510+
0,
511+
)
512+
)
513+
res.append(
514+
Composable(
515+
defined_composable.id,
516+
merged_title,
517+
merged_default,
518+
merged_dependencies,
519+
merged_options,
520+
)
521+
)
522+
523+
for composable_obj in custom_composable_by_id.values():
524+
res.append(
525+
Composable(
526+
composable_obj["id"],
527+
composable_obj["title"],
528+
composable_obj["default"] if "default" in composable_obj else None,
529+
(
530+
composable_obj["dependencies"]
531+
if "dependencies" in composable_obj
532+
else None
533+
),
534+
list(
535+
map(
536+
lambda option: TabDefinition(option["id"], option["title"]),
537+
composable_obj["options"],
538+
)
539+
),
540+
)
541+
)
542+
543+
spec.composables = res
544+
return (spec, diagnostics)
545+
546+
@classmethod
547+
def initialize(cls, text: str, configPath: Optional[Path]) -> "Spec":
443548
cls.SPEC = Spec.loads(text)
549+
if configPath:
550+
project_config = tomli.loads(configPath.read_text(encoding="utf-8"))
551+
# NOTE: would like to check_type but circular imports
552+
# this is already verified earlier in the process
553+
spec, diagnostics = cls._merge_composables(
554+
cls.SPEC,
555+
(
556+
project_config["composables"]
557+
if "composables" in project_config
558+
else []
559+
),
560+
)
561+
spec.merged = True
562+
cls.SPEC = spec
563+
for diagnostic in diagnostics:
564+
logger.error(diagnostic)
565+
return cls.SPEC
444566

445567
@classmethod
446-
def get(cls) -> "Spec":
447-
if cls.SPEC is None:
448-
path = util.PACKAGE_ROOT.joinpath("rstspec.toml")
449-
cls.initialize(path.read_text(encoding="utf-8"))
450-
451-
spec = cls.SPEC
452-
assert spec is not None
453-
return spec
568+
def get(cls, configPath: Optional[Path] = None) -> "Spec":
569+
if cls.SPEC and cls.SPEC.merged:
570+
return cls.SPEC
571+
572+
path = util.PACKAGE_ROOT.joinpath("rstspec.toml")
573+
spec = cls.initialize(path.read_text(encoding="utf-8"), configPath)
574+
575+
cls.SPEC = spec
576+
assert cls.SPEC is not None
577+
return cls.SPEC

snooty/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ class ProjectConfig:
218218
bundle: BundleConfig = field(default_factory=BundleConfig)
219219
data: Dict[str, object] = field(default_factory=dict)
220220
associated_products: List[AssociatedProduct] = field(default_factory=list)
221+
composables: List[specparser.Composable] = field(default_factory=list)
221222

222223
# banner_nodes contains parsed banner nodes with target data
223224
banner_nodes: List[ParsedBannerConfig] = field(

0 commit comments

Comments
 (0)