|
3 | 3 | import os |
4 | 4 | from collections import OrderedDict |
5 | 5 | from collections.abc import Generator |
6 | | -from io import StringIO, BytesIO |
| 6 | +from io import BytesIO |
| 7 | +from pathlib import PurePosixPath |
| 8 | +from typing import IO, TypeVar |
7 | 9 |
|
8 | 10 | from databricks.sdk import WorkspaceClient |
9 | 11 | from databricks.sdk.service.workspace import ObjectInfo |
10 | 12 | from databricks.labs.blueprint.paths import WorkspacePath |
11 | 13 |
|
| 14 | +from databricks.labs.ucx.source_code.base import decode_with_bom |
12 | 15 |
|
13 | | -class _CachedIO: |
14 | 16 |
|
15 | | - def __init__(self, content): |
16 | | - self._content = content |
17 | | - self._index = 0 |
18 | | - |
19 | | - def __enter__(self): |
20 | | - return self |
21 | | - |
22 | | - def __exit__(self, exc_type, exc_val, exc_tb): |
23 | | - return False |
| 17 | +# lru_cache won't let us invalidate cache entries |
| 18 | +# so we provide our own custom lru_cache |
| 19 | +class _PathLruCache: |
24 | 20 |
|
25 | | - def read(self, *args, **_kwargs): |
26 | | - count = -1 if len(args) < 1 or args[0] < 1 else args[0] |
27 | | - if count == -1: |
28 | | - return self._content |
29 | | - start = self._index |
30 | | - end = self._index + count |
31 | | - if start >= len(self._content): |
32 | | - return None |
33 | | - self._index = self._index + count |
34 | | - return self._content[start:end] |
| 21 | + _datas: OrderedDict[PurePosixPath, bytes] |
| 22 | + """Cached binary data of files, keyed by workspace path, ordered from oldest to newest.""" |
35 | 23 |
|
36 | | - def __iter__(self): |
37 | | - if isinstance(self._content, str): |
38 | | - yield from StringIO(self._content) |
39 | | - return |
40 | | - yield from self._as_string_io().__iter__() |
| 24 | + _max_entries: int |
| 25 | + """The maximum number of entries to hold in the cache.""" |
41 | 26 |
|
42 | | - def with_mode(self, mode: str): |
43 | | - if 'b' in mode: |
44 | | - return self._as_bytes_io() |
45 | | - return self._as_string_io() |
| 27 | + def __init__(self, max_entries: int) -> None: |
| 28 | + # Ordered from oldest to newest. |
| 29 | + self._datas = OrderedDict() |
| 30 | + self._max_entries = max_entries |
46 | 31 |
|
47 | | - def _as_bytes_io(self): |
48 | | - if isinstance(self._content, bytes): |
49 | | - return self |
50 | | - return BytesIO(self._content.encode("utf-8-sig")) |
| 32 | + @classmethod |
| 33 | + def _normalize(cls, path: _CachedPath) -> PurePosixPath: |
| 34 | + # Note: must not return the same instance that was passed in, to avoid circular references (and memory leaks). |
| 35 | + return PurePosixPath(*path.parts) |
51 | 36 |
|
52 | | - def _as_string_io(self): |
53 | | - if isinstance(self._content, str): |
54 | | - return self |
55 | | - return StringIO(self._content.decode("utf-8")) |
| 37 | + def load(self, cached_path: _CachedPath, buffering: int = -1) -> bytes: |
| 38 | + normalized_path = self._normalize(cached_path) |
56 | 39 |
|
| 40 | + data = self._datas.get(normalized_path, None) |
| 41 | + if data is not None: |
| 42 | + self._datas.move_to_end(normalized_path) |
| 43 | + return data |
57 | 44 |
|
58 | | -# lru_cache won't let us invalidate cache entries |
59 | | -# so we provide our own custom lru_cache |
60 | | -class _PathLruCache: |
61 | | - |
62 | | - def __init__(self, max_entries: int): |
63 | | - self._datas: OrderedDict[str, bytes | str] = OrderedDict() |
64 | | - self._max_entries = max_entries |
65 | | - |
66 | | - def open(self, cached_path: _CachedPath, mode, buffering, encoding, errors, newline): |
67 | | - path = str(cached_path) |
68 | | - if path in self._datas: |
69 | | - self._datas.move_to_end(path) |
70 | | - return _CachedIO(self._datas[path]).with_mode(mode) |
71 | | - io_obj = WorkspacePath.open(cached_path, mode, buffering, encoding, errors, newline) |
72 | | - # can't read twice from an IO so need to cache data rather than the io object |
73 | | - data = io_obj.read() |
74 | | - self._datas[path] = data |
75 | | - result = _CachedIO(data).with_mode(mode) |
76 | | - if len(self._datas) > self._max_entries: |
| 45 | + # Need to bypass the _CachedPath.open() override to actually open and retrieve the file content. |
| 46 | + with WorkspacePath.open(cached_path, mode="rb", buffering=buffering) as workspace_file: |
| 47 | + data = workspace_file.read() |
| 48 | + if self._max_entries <= len(self._datas): |
77 | 49 | self._datas.popitem(last=False) |
78 | | - return result |
| 50 | + self._datas[normalized_path] = data |
| 51 | + return data |
79 | 52 |
|
80 | | - def clear(self): |
| 53 | + def clear(self) -> None: |
81 | 54 | self._datas.clear() |
82 | 55 |
|
83 | | - def remove(self, path: str): |
84 | | - if path in self._datas: |
85 | | - self._datas.pop(path) |
| 56 | + def remove(self, path: _CachedPath) -> None: |
| 57 | + del self._datas[self._normalize(path)] |
86 | 58 |
|
87 | 59 |
|
class _CachedPath(WorkspacePath):
    """A workspace path that serves reads from a shared content cache (_PathLruCache)."""

    def __init__(self, cache: _PathLruCache, ws: WorkspaceClient, *args: str | bytes | os.PathLike) -> None:
        super().__init__(ws, *args)
        # The shared cache; all paths derived from this one reuse the same instance.
        self._cache = cache

    @classmethod
    def _from_object_info_with_cache(
        cls,
        cache: _PathLruCache,
        ws: WorkspaceClient,
        object_info: ObjectInfo,
    ) -> _CachedPath:
        """Create an instance from an ObjectInfo, pre-populating the cached object-info attribute.

        NOTE(review): presumably WorkspacePath reads _cached_object_info to avoid a metadata
        round-trip — confirm against the superclass implementation.
        """
        assert object_info.path
        path = cls(cache, ws, object_info.path)
        path._cached_object_info = object_info
        return path

    def with_segments(self: _CachedPathT, *path_segments: bytes | str | os.PathLike) -> _CachedPathT:
        # Derived paths must carry over the cache as well as the workspace client.
        return type(self)(self._cache, self._ws, *path_segments)

    def iterdir(self) -> Generator[_CachedPath, None, None]:
        # Variant of the superclass implementation that preserves the cache, as well as the client.
        for object_info in self._ws.workspace.list(self.as_posix()):
            yield self._from_object_info_with_cache(self._cache, self._ws, object_info)

    def open(  # type: ignore[override]
        self,
        mode: str = "r",
        buffering: int = -1,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> IO:
        """Open the file, serving read content from the cache.

        Binary reads return a BytesIO over the cached bytes; text reads decode those bytes with
        BOM detection (decode_with_bom).

        NOTE(review): _PathLruCache.remove raises KeyError when the path was never cached — verify
        write-mode opens of files that were never read.
        NOTE(review): modes other than 'w' and plain reads (e.g. append 'a', exclusive 'x') fall
        through to the cached read path here — confirm such modes are never used with these paths.
        """
        # We only cache reads; if a write happens we use the default implementation (and evict any cache entry).
        if 'w' in mode:
            self._cache.remove(self)
            return super().open(mode, buffering, encoding, errors, newline)

        binary_data = self._cache.load(self, buffering=buffering)
        binary_io = BytesIO(binary_data)
        if 'b' in mode:
            return binary_io

        return decode_with_bom(binary_io, encoding, errors, newline)

    # _rename calls unlink so no need to override it
    def unlink(self, missing_ok: bool = False) -> None:
        """Delete the file, evicting its content from the cache first."""
        self._cache.remove(self)
        return super().unlink(missing_ok)
130 | 109 |
|
131 | 110 |
|
# TypeVar bound to _CachedPath so that with_segments() is typed to return the instance's own (sub)class.
_CachedPathT = TypeVar("_CachedPathT", bound=_CachedPath)
| 112 | + |
| 113 | + |
class WorkspaceCache:
    """Factory for workspace paths whose file content is cached after the first read."""

    class InvalidWorkspacePath(ValueError):
        """Raised when a supplied string is not a syntactically valid workspace path."""

    def __init__(self, ws: WorkspaceClient, max_entries: int = 2048) -> None:
        self._ws = ws
        self._cache = _PathLruCache(max_entries)

    def get_workspace_path(self, path: str) -> WorkspacePath:
        """Obtain a `WorkspacePath` instance for a path that refers to a workspace file or notebook.

        The instance returned participates in this content cache: the first time the path is opened the content will
        be immediately retrieved (prior to reading) and cached.

        Args:
            path: a valid workspace path (must be absolute)
        Raises:
            WorkspaceCache.InvalidWorkspacePath: this is raised immediately if the supplied path is not a syntactically
                valid workspace path. (This is not raised if the path is syntactically valid but does not exist.)
        """
        if path.startswith("/"):
            return _CachedPath(self._cache, self._ws, path)
        msg = f"Invalid workspace path; must be absolute and start with a slash ('/'): {path}"
        raise WorkspaceCache.InvalidWorkspacePath(msg)
0 commit comments