Skip to content

Commit 82642a8

Browse files
authored
Update the workspace content cache used during linting (#2953)
## Changes This PR began as an effort to retrofit the `WorkspaceCache` (and associated implementation) with type hints to help ensure correctness. This proved to be quite difficult due to the dynamic nature of the cache and binary vs. text modes. As such, this PR updates the `WorkspaceCache` in the following ways: - We now only cache the binary content of workspace files. This simplifies the implementation substantially and avoids correctness issues with respect to various options that can be specified to control decoding as binary content. (Decoding is well optimised compared to upstream/downstream processing, so this is a reasonable tradeoff.) - The cache now resolves paths, ensuring that different ways of referring to the same content are handled as a single cache entry. - We now raise an error immediately if a non-absolute path is requested: workspace paths are always absolute, so this will catch errors earlier. - For the paths where we check for a BOM-marker prior to reading as text, if the opened stream for the path is seekable then we only open it once instead of twice. ### Tests - added (and updated existing) unit tests - existing integration tests
1 parent c66893f commit 82642a8

File tree

7 files changed

+224
-148
lines changed

7 files changed

+224
-148
lines changed

src/databricks/labs/ucx/mixins/cached_workspace_path.py

Lines changed: 86 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -3,137 +3,136 @@
33
import os
44
from collections import OrderedDict
55
from collections.abc import Generator
6-
from io import StringIO, BytesIO
6+
from io import BytesIO
7+
from pathlib import PurePosixPath
8+
from typing import IO, TypeVar
79

810
from databricks.sdk import WorkspaceClient
911
from databricks.sdk.service.workspace import ObjectInfo
1012
from databricks.labs.blueprint.paths import WorkspacePath
1113

14+
from databricks.labs.ucx.source_code.base import decode_with_bom
1215

13-
class _CachedIO:
1416

15-
def __init__(self, content):
16-
self._content = content
17-
self._index = 0
18-
19-
def __enter__(self):
20-
return self
21-
22-
def __exit__(self, exc_type, exc_val, exc_tb):
23-
return False
17+
# lru_cache won't let us invalidate cache entries
18+
# so we provide our own custom lru_cache
19+
class _PathLruCache:
2420

25-
def read(self, *args, **_kwargs):
26-
count = -1 if len(args) < 1 or args[0] < 1 else args[0]
27-
if count == -1:
28-
return self._content
29-
start = self._index
30-
end = self._index + count
31-
if start >= len(self._content):
32-
return None
33-
self._index = self._index + count
34-
return self._content[start:end]
21+
_datas: OrderedDict[PurePosixPath, bytes]
22+
"""Cached binary data of files, keyed by workspace path, ordered from oldest to newest."""
3523

36-
def __iter__(self):
37-
if isinstance(self._content, str):
38-
yield from StringIO(self._content)
39-
return
40-
yield from self._as_string_io().__iter__()
24+
_max_entries: int
25+
"""The maximum number of entries to hold in the cache."""
4126

42-
def with_mode(self, mode: str):
43-
if 'b' in mode:
44-
return self._as_bytes_io()
45-
return self._as_string_io()
27+
def __init__(self, max_entries: int) -> None:
28+
# Ordered from oldest to newest.
29+
self._datas = OrderedDict()
30+
self._max_entries = max_entries
4631

47-
def _as_bytes_io(self):
48-
if isinstance(self._content, bytes):
49-
return self
50-
return BytesIO(self._content.encode("utf-8-sig"))
32+
@classmethod
33+
def _normalize(cls, path: _CachedPath) -> PurePosixPath:
34+
# Note: must not return the same instance that was passed in, to avoid circular references (and memory leaks).
35+
return PurePosixPath(*path.parts)
5136

52-
def _as_string_io(self):
53-
if isinstance(self._content, str):
54-
return self
55-
return StringIO(self._content.decode("utf-8"))
37+
def load(self, cached_path: _CachedPath, buffering: int = -1) -> bytes:
38+
normalized_path = self._normalize(cached_path)
5639

40+
data = self._datas.get(normalized_path, None)
41+
if data is not None:
42+
self._datas.move_to_end(normalized_path)
43+
return data
5744

58-
# lru_cache won't let us invalidate cache entries
59-
# so we provide our own custom lru_cache
60-
class _PathLruCache:
61-
62-
def __init__(self, max_entries: int):
63-
self._datas: OrderedDict[str, bytes | str] = OrderedDict()
64-
self._max_entries = max_entries
65-
66-
def open(self, cached_path: _CachedPath, mode, buffering, encoding, errors, newline):
67-
path = str(cached_path)
68-
if path in self._datas:
69-
self._datas.move_to_end(path)
70-
return _CachedIO(self._datas[path]).with_mode(mode)
71-
io_obj = WorkspacePath.open(cached_path, mode, buffering, encoding, errors, newline)
72-
# can't read twice from an IO so need to cache data rather than the io object
73-
data = io_obj.read()
74-
self._datas[path] = data
75-
result = _CachedIO(data).with_mode(mode)
76-
if len(self._datas) > self._max_entries:
45+
# Need to bypass the _CachedPath.open() override to actually open and retrieve the file content.
46+
with WorkspacePath.open(cached_path, mode="rb", buffering=buffering) as workspace_file:
47+
data = workspace_file.read()
48+
if self._max_entries <= len(self._datas):
7749
self._datas.popitem(last=False)
78-
return result
50+
self._datas[normalized_path] = data
51+
return data
7952

80-
def clear(self):
53+
def clear(self) -> None:
8154
self._datas.clear()
8255

83-
def remove(self, path: str):
84-
if path in self._datas:
85-
self._datas.pop(path)
56+
def remove(self, path: _CachedPath) -> None:
57+
del self._datas[self._normalize(path)]
8658

8759

8860
class _CachedPath(WorkspacePath):
89-
def __init__(self, cache: _PathLruCache, ws: WorkspaceClient, *args: str | bytes | os.PathLike):
61+
def __init__(self, cache: _PathLruCache, ws: WorkspaceClient, *args: str | bytes | os.PathLike) -> None:
9062
super().__init__(ws, *args)
9163
self._cache = cache
9264

93-
def with_object_info(self, object_info: ObjectInfo):
94-
self._cached_object_info = object_info
95-
return self
96-
97-
def with_segments(self, *path_segments: bytes | str | os.PathLike) -> _CachedPath:
65+
@classmethod
66+
def _from_object_info_with_cache(
67+
cls,
68+
cache: _PathLruCache,
69+
ws: WorkspaceClient,
70+
object_info: ObjectInfo,
71+
) -> _CachedPath:
72+
assert object_info.path
73+
path = cls(cache, ws, object_info.path)
74+
path._cached_object_info = object_info
75+
return path
76+
77+
def with_segments(self: _CachedPathT, *path_segments: bytes | str | os.PathLike) -> _CachedPathT:
9878
return type(self)(self._cache, self._ws, *path_segments)
9979

10080
def iterdir(self) -> Generator[_CachedPath, None, None]:
81+
# Variant of the superclass implementation that preserves the cache, as well as the client.
10182
for object_info in self._ws.workspace.list(self.as_posix()):
102-
path = object_info.path
103-
if path is None:
104-
msg = f"Cannot initialise without object path: {object_info}"
105-
raise ValueError(msg)
106-
child = _CachedPath(self._cache, self._ws, path)
107-
yield child.with_object_info(object_info)
108-
109-
def open(
83+
yield self._from_object_info_with_cache(self._cache, self._ws, object_info)
84+
85+
def open( # type: ignore[override]
11086
self,
11187
mode: str = "r",
11288
buffering: int = -1,
11389
encoding: str | None = None,
11490
errors: str | None = None,
11591
newline: str | None = None,
116-
):
117-
# only cache reads
118-
if 'r' in mode:
119-
return self._cache.open(self, mode, buffering, encoding, errors, newline)
120-
self._cache.remove(str(self))
121-
return super().open(mode, buffering, encoding, errors, newline)
92+
) -> IO:
93+
# We only cache reads; if a write happens we use the default implementation (and evict any cache entry).
94+
if 'w' in mode:
95+
self._cache.remove(self)
96+
return super().open(mode, buffering, encoding, errors, newline)
97+
98+
binary_data = self._cache.load(self, buffering=buffering)
99+
binary_io = BytesIO(binary_data)
100+
if 'b' in mode:
101+
return binary_io
122102

123-
def _cached_open(self, mode: str, buffering: int, encoding: str | None, errors: str | None, newline: str | None):
124-
return super().open(mode, buffering, encoding, errors, newline)
103+
return decode_with_bom(binary_io, encoding, errors, newline)
125104

126105
# _rename calls unlink so no need to override it
127106
def unlink(self, missing_ok: bool = False) -> None:
128-
self._cache.remove(str(self))
107+
self._cache.remove(self)
129108
return super().unlink(missing_ok)
130109

131110

111+
_CachedPathT = TypeVar("_CachedPathT", bound=_CachedPath)
112+
113+
132114
class WorkspaceCache:
133115

134-
def __init__(self, ws: WorkspaceClient, max_entries=2048):
116+
class InvalidWorkspacePath(ValueError):
117+
pass
118+
119+
def __init__(self, ws: WorkspaceClient, max_entries: int = 2048) -> None:
135120
self._ws = ws
136121
self._cache = _PathLruCache(max_entries)
137122

138-
def get_path(self, path: str):
123+
def get_workspace_path(self, path: str) -> WorkspacePath:
124+
"""Obtain a `WorkspacePath` instance for a path that refers to a workspace file or notebook.
125+
126+
The instance returned participates in this content cache: the first time the path is opened the content will
127+
be immediately retrieved (prior to reading) and cached.
128+
129+
Args:
130+
path: a valid workspace path (must be absolute)
131+
Raises:
132+
WorkspaceCache.InvalidWorkspacePath: this is raised immediately if the supplied path is not a syntactically
133+
valid workspace path. (This is not raised if the path is syntactically valid but does not exist.)
134+
"""
135+
if not path.startswith("/"):
136+
msg = f"Invalid workspace path; must be absolute and start with a slash ('/'): {path}"
137+
raise WorkspaceCache.InvalidWorkspacePath(msg)
139138
return _CachedPath(self._cache, self._ws, path)

src/databricks/labs/ucx/source_code/base.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
import codecs
44
import dataclasses
5-
import locale
5+
import io
66
import logging
77
import sys
88
from abc import abstractmethod, ABC
99
from collections.abc import Iterable
1010
from dataclasses import dataclass, field
1111
from datetime import datetime
1212
from pathlib import Path
13-
from typing import Any
13+
from typing import Any, BinaryIO, TextIO
1414

1515
from astroid import AstroidSyntaxError, NodeNG # type: ignore
1616
from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError
@@ -482,18 +482,71 @@ def file_language(path: Path) -> Language | None:
482482
return SUPPORTED_EXTENSION_LANGUAGES.get(path.suffix.lower())
483483

484484

485-
def guess_encoding(path: Path) -> str:
486-
# some files encode a unicode BOM (byte-order-mark), so let's use that if available
487-
with path.open('rb') as _file:
488-
raw = _file.read(4)
489-
if raw.startswith(codecs.BOM_UTF32_LE) or raw.startswith(codecs.BOM_UTF32_BE):
490-
return 'utf-32'
491-
if raw.startswith(codecs.BOM_UTF16_LE) or raw.startswith(codecs.BOM_UTF16_BE):
492-
return 'utf-16'
493-
if raw.startswith(codecs.BOM_UTF8):
494-
return 'utf-8-sig'
495-
# no BOM, let's use default encoding
496-
return locale.getpreferredencoding(False)
485+
def _detect_encoding_bom(binary_io: BinaryIO, *, preserve_position: bool) -> str | None:
486+
# Peek at the first (up to) 4 bytes, preserving the file position if requested.
487+
position = binary_io.tell() if preserve_position else None
488+
try:
489+
maybe_bom = binary_io.read(4)
490+
finally:
491+
if position is not None:
492+
binary_io.seek(position)
493+
# For these encodings, TextIOWrapper will skip over the BOM during decoding.
494+
if maybe_bom.startswith(codecs.BOM_UTF32_LE) or maybe_bom.startswith(codecs.BOM_UTF32_BE):
495+
return "utf-32"
496+
if maybe_bom.startswith(codecs.BOM_UTF16_LE) or maybe_bom.startswith(codecs.BOM_UTF16_BE):
497+
return "utf-16"
498+
if maybe_bom.startswith(codecs.BOM_UTF8):
499+
return "utf-8-sig"
500+
return None
501+
502+
503+
def decode_with_bom(
504+
file: BinaryIO,
505+
encoding: str | None = None,
506+
errors: str | None = None,
507+
newline: str | None = None,
508+
) -> TextIO:
509+
"""Wrap an open binary file with a text decoder.
510+
511+
This has the same semantics as the built-in `open()` call, except that if the encoding is not specified and the
512+
file is seekable then it will be checked for a BOM. If a BOM marker is found, that encoding is used. When neither
513+
an encoding nor a BOM are present the encoding of the system locale is used.
514+
515+
Args:
516+
file: the open (binary) file to wrap in text mode.
517+
encoding: force decoding with a specific encoding. If not present, the file BOM and system locale are used.
518+
errors: how decoding errors should be handled, as per open().
519+
newline: how newlines should be handled, as per open().
520+
Raises:
521+
ValueError: if the encoding should be detected via potential BOM marker but the file is not seekable.
522+
Returns:
523+
a text-based IO wrapper that will decode the underlying binary-mode file as text.
524+
"""
525+
use_encoding = _detect_encoding_bom(file, preserve_position=True) if encoding is None else encoding
526+
return io.TextIOWrapper(file, encoding=use_encoding, errors=errors, newline=newline)
527+
528+
529+
def read_text(path: Path, size: int = -1) -> str:
530+
"""Read a file as text, decoding according to the BOM marker if that is present.
531+
532+
This differs from the normal `.read_text()` method on `Path`, which does not support BOM markers.
533+
534+
Arguments:
535+
path: the path to read text from.
536+
size: how much text (measured in characters) to read. If negative, all text is read. Less may be read if the
537+
file is smaller than the specified size.
538+
Returns:
539+
The string content of the file, up to the specified size.
540+
"""
541+
with path.open("rb") as binary_io:
542+
# If the open file is seekable, we can detect the BOM and decode without re-opening.
543+
if binary_io.seekable():
544+
with decode_with_bom(binary_io) as f:
545+
return f.read(size)
546+
encoding = _detect_encoding_bom(binary_io, preserve_position=False)
547+
# Otherwise having read the BOM there's no way to rewind so we need to re-open and read from that.
548+
with path.open("rt", encoding=encoding) as f:
549+
return f.read(size)
497550

498551

499552
# duplicated from CellLanguage to prevent cyclic import
@@ -513,8 +566,7 @@ def is_a_notebook(path: Path, content: str | None = None) -> bool:
513566
if content is not None:
514567
return content.startswith(magic_header)
515568
try:
516-
with path.open('rt', encoding=guess_encoding(path)) as f:
517-
file_header = f.read(len(magic_header))
569+
file_header = read_text(path, size=len(magic_header))
518570
except (FileNotFoundError, UnicodeDecodeError, PermissionError):
519571
logger.warning(f"Could not read file {path}")
520572
return False

0 commit comments

Comments
 (0)