Skip to content

Commit 5f8459b

Browse files
committed
DOP-1282: Make postprocessor nondestructive
1 parent bcde51e commit 5f8459b

File tree

3 files changed

+88
-75
lines changed

3 files changed

+88
-75
lines changed

snooty/parser.py

Lines changed: 79 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import collections
2-
import dataclasses
32
import docutils.nodes
43
import logging
54
import multiprocessing
@@ -12,7 +11,17 @@
1211
from dataclasses import dataclass
1312
from functools import partial
1413
from pathlib import Path, PurePath
15-
from typing import Any, Dict, MutableSequence, Tuple, Optional, Set, List, Iterable
14+
from typing import (
15+
Any,
16+
Callable,
17+
Dict,
18+
MutableSequence,
19+
Tuple,
20+
Optional,
21+
Set,
22+
List,
23+
Iterable,
24+
)
1625
from docutils.nodes import make_id
1726
from typing_extensions import Protocol
1827
import docutils.utils
@@ -808,6 +817,62 @@ def on_delete(self, page_id: FileId, build_identifiers: BuildIdentifierSet) -> N
808817
...
809818

810819

820+
class PageDatabase:
821+
"""A database of FileId->Page mappings that ensures the postprocessing pipeline
822+
is run correctly. Raw parsed pages are added, flush() is called, then postprocessed
823+
pages can be accessed."""
824+
825+
def __init__(self, postprocessor_factory: Callable[[], Postprocessor]) -> None:
826+
self.postprocessor_factory = postprocessor_factory
827+
self.parsed: Dict[FileId, Page] = {}
828+
self.__postprocessed: Dict[FileId, Page] = {}
829+
self.__changed_pages: Set[FileId] = set()
830+
831+
def __setitem__(self, key: FileId, value: Page) -> None:
832+
"""Set a raw parsed page."""
833+
self.parsed[key] = value
834+
self.__changed_pages.add(key)
835+
836+
def __getitem__(self, key: FileId) -> Page:
837+
"""If the postprocessor has been run since modifications were made, fetch a postprocessed page."""
838+
assert not self.__changed_pages
839+
return self.__postprocessed[key]
840+
841+
def __contains__(self, key: FileId) -> bool:
842+
"""Check if a given page exists in the parsed set."""
843+
return key in self.parsed
844+
845+
def values(self) -> Iterable[Page]:
846+
"""Iterate over postprocessed pages."""
847+
assert not self.__changed_pages
848+
return self.__postprocessed.values()
849+
850+
def items(self) -> Iterable[Tuple[FileId, Page]]:
851+
"""Iterate over the postprocessed (FileId, Page) set."""
852+
assert not self.__changed_pages
853+
return self.__postprocessed.items()
854+
855+
def flush(
856+
self
857+
) -> Tuple[Dict[str, SerializableType], Dict[FileId, List[Diagnostic]]]:
858+
"""Run the postprocessor if and only if any pages have changed, and return postprocessing results."""
859+
if not self.__changed_pages:
860+
return {}, {}
861+
862+
postprocessor = self.postprocessor_factory()
863+
864+
with util.PerformanceLogger.singleton().start("copy"):
865+
copied_pages = util.fast_deep_copy(self.parsed)
866+
867+
with util.PerformanceLogger.singleton().start("postprocessing"):
868+
post_metadata, post_diagnostics = postprocessor.run(copied_pages)
869+
870+
self.__postprocessed = postprocessor.pages
871+
self.__changed_pages.clear()
872+
873+
return post_metadata, post_diagnostics
874+
875+
811876
class _Project:
812877
"""Internal representation of a Snooty project with no data locking."""
813878

@@ -832,12 +897,6 @@ def __init__(
832897
self.filesystem_watcher = filesystem_watcher
833898
self.build_identifiers = build_identifiers
834899

835-
self.postprocessor = (
836-
DevhubPostprocessor(self.config, self.targets)
837-
if self.config.default_domain == "devhub"
838-
else Postprocessor(self.config, self.targets)
839-
)
840-
841900
self.yaml_mapping: Dict[str, GizaCategory[Any]] = {
842901
"steps": gizaparser.steps.GizaStepsCategory(self.config),
843902
"extracts": gizaparser.extracts.GizaExtractsCategory(self.config),
@@ -866,7 +925,11 @@ def __init__(
866925
).strip()
867926
self.prefix = [self.config.name, username, branch]
868927

869-
self.pages: Dict[FileId, Page] = {}
928+
self.pages = PageDatabase(
929+
lambda: DevhubPostprocessor(self.config, self.targets)
930+
if self.config.default_domain == "devhub"
931+
else Postprocessor(self.config, self.targets)
932+
)
870933

871934
self.asset_dg: "networkx.DiGraph[FileId]" = networkx.DiGraph()
872935
self.expensive_operation_cache: Cache[FileId] = Cache()
@@ -906,11 +969,13 @@ def get_page_ast(self, path: Path) -> n.Node:
906969
"""Update page file (.txt) with current text and return fully populated page AST"""
907970
# Get incomplete AST of page
908971
fileid = self.get_fileid(path)
972+
post_metadata, post_diagnostics = self.pages.flush()
973+
for fileid, diagnostics in post_diagnostics.items():
974+
self.backend.on_diagnostics(fileid, diagnostics)
909975
page = self.pages[fileid]
910976

911-
# Fill in missing include nodes
912977
assert isinstance(page.ast, n.Parent)
913-
return self._populate_include_nodes(page.ast)
978+
return page.ast
914979

915980
def get_project_name(self) -> str:
916981
return self.config.name
@@ -1042,8 +1107,7 @@ def create_page(filename: str) -> Tuple[Page, EmbeddedRstParser]:
10421107
page, all_yaml_diagnostics.get(page.source_path, [])
10431108
)
10441109

1045-
with util.PerformanceLogger.singleton().start("postprocessing"):
1046-
post_metadata, post_diagnostics = self.postprocessor.run(self.pages)
1110+
post_metadata, post_diagnostics = self.pages.flush()
10471111

10481112
static_files = {
10491113
"objects.inv": self.targets.generate_inventory("").dumps(
@@ -1052,7 +1116,7 @@ def create_page(filename: str) -> Tuple[Page, EmbeddedRstParser]:
10521116
}
10531117
post_metadata["static_files"] = static_files
10541118

1055-
for fileid, page in self.postprocessor.pages.items():
1119+
for fileid, page in self.pages.items():
10561120
self.backend.on_update(self.prefix, self.build_identifiers, fileid, page)
10571121
for fileid, diagnostics in post_diagnostics.items():
10581122
self.backend.on_diagnostics(fileid, diagnostics)
@@ -1061,62 +1125,6 @@ def create_page(filename: str) -> Tuple[Page, EmbeddedRstParser]:
10611125
self.prefix, self.build_identifiers, post_metadata
10621126
)
10631127

1064-
def _populate_include_nodes(self, root: n.Parent[n.Node]) -> n.Node:
1065-
"""
1066-
Add include nodes to page AST's children.
1067-
1068-
To render images on the Snooty extension's Snooty Preview,
1069-
we must use the full path of the image on the user's local machine. Note that this does change the
1070-
figure's value within the parser's dict. However, this should not change the value when using the parser
1071-
outside of Snooty Preview, since this function is currently only called by the language server.
1072-
"""
1073-
1074-
def replace_nodes(node: n.Node) -> n.Node:
1075-
if isinstance(node, n.Directive):
1076-
if node.name == "include":
1077-
# Get the name of the file
1078-
argument = node.argument[0]
1079-
include_filename = argument.value
1080-
include_filename = include_filename[1:]
1081-
1082-
# Get children of include file
1083-
include_file_page_ast = self.pages[FileId(include_filename)].ast
1084-
assert isinstance(include_file_page_ast, n.Parent)
1085-
include_node_children = include_file_page_ast.children
1086-
1087-
# Resolve includes within include node
1088-
replaced_include = list(map(replace_nodes, include_node_children))
1089-
node.children = replaced_include
1090-
# Replace instances of an image's name with its full path. This allows Snooty Preview to render an image by
1091-
# using the location of the image on the user's local machine
1092-
elif node.name == "figure":
1093-
# Obtain subset of the image's path (name)
1094-
argument = node.argument[0]
1095-
image_value = argument.value
1096-
1097-
# Prevents the image from having a redundant path if Snooty Preview already replaced
1098-
# its original value.
1099-
source_path_str = self.config.source_path.as_posix()
1100-
index_match = image_value.find(source_path_str)
1101-
if index_match != -1:
1102-
repeated_offset = index_match + len(source_path_str)
1103-
image_value = image_value[repeated_offset:]
1104-
1105-
# Replace subset of path with full path of image
1106-
if image_value[0] == "/":
1107-
image_value = image_value[1:]
1108-
full_path = self.get_full_path(FileId(image_value))
1109-
argument.value = full_path.as_posix()
1110-
# Check for include nodes among current node's children
1111-
elif isinstance(node, n.Parent):
1112-
for child in node.children:
1113-
replace_nodes(child)
1114-
return node
1115-
1116-
return dataclasses.replace(
1117-
root, children=list(map(replace_nodes, root.children))
1118-
)
1119-
11201128
def _page_updated(self, page: Page, diagnostics: List[Diagnostic]) -> None:
11211129
"""Update any state associated with a parsed page."""
11221130
# Finish any pending tasks
@@ -1130,7 +1138,7 @@ def _page_updated(self, page: Page, diagnostics: List[Diagnostic]) -> None:
11301138
logger.debug("Updated: %s", fileid)
11311139

11321140
if fileid in self.pages:
1133-
old_page = self.pages[fileid]
1141+
old_page = self.pages.parsed[fileid]
11341142
old_assets = old_page.static_assets
11351143
removed_assets = old_page.static_assets.difference(page.static_assets)
11361144

snooty/test_language_server.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,6 @@ def test_text_doc_get_page_ast() -> None:
140140

141141
source_path = server.project.config.source_path
142142

143-
# Image found in test file
144-
image_path = Path("images/compass-create-database.png")
145-
full_image_path = source_path.joinpath(image_path)
146143
# Change image path to be full path
147144
index_ast_string = (
148145
"""<root>
@@ -153,7 +150,7 @@ def test_text_doc_get_page_ast() -> None:
153150
<heading id="id1"><text>Guides</text></heading>
154151
<directive name="figure" alt="Sample images" checksum="10e351828f156afcafc7744c30d7b2564c6efba1ca7c55cac59560c67581f947">
155152
<text>"""
156-
+ full_image_path.as_posix()
153+
+ "/images/compass-create-database.png"
157154
+ """</text></directive>
158155
<directive name="include"><text>/includes/test_rst.rst</text>
159156
<directive name="include"><text>/includes/include_child.rst</text>

snooty/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
import pickle
34
import time
45
import docutils.nodes
56
import docutils.parsers.rst.directives
@@ -27,6 +28,7 @@
2728
from . import n
2829

2930
logger = logging.getLogger(__name__)
31+
_T = TypeVar("_T")
3032
_K = TypeVar("_K", bound=Hashable)
3133
SOURCE_FILE_EXTENSIONS = {".txt", ".rst", ".yaml"}
3234
RST_EXTENSIONS = {".txt", ".rst"}
@@ -238,6 +240,12 @@ def split_domain(name: str) -> Tuple[str, str]:
238240
return parts[0], parts[1]
239241

240242

243+
def fast_deep_copy(d: Dict[_K, _T]) -> Dict[_K, _T]:
244+
"""Time-efficiently create deep copy of a dictionary containing trusted data.
245+
This implementation currently invokes pickle, so should NOT be called on untrusted objects."""
246+
return {k: pickle.loads(pickle.dumps(v)) for k, v in d.items()}
247+
248+
241249
class PerformanceLogger:
242250
_singleton: Optional["PerformanceLogger"] = None
243251

0 commit comments

Comments
 (0)