
Commit fc554ae

Cache workspace content (#2497)
## Changes

Loading workspace content is slow and bound by rate limits. This PR introduces a cache for workspace content.

### Linked issues

None

### Functionality

None

### Tests

- [x] added unit tests

Co-authored-by: Eric Vergnaud <[email protected]>
1 parent e5e0562 commit fc554ae
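For orientation, a minimal usage sketch (not part of the diff; the workspace path is illustrative and client authentication is assumed to be configured):

```python
from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache

ws = WorkspaceClient()  # assumes ambient Databricks authentication
cache = WorkspaceCache(ws)  # keeps up to 2048 entries by default

# The first read downloads the file and caches its content; later reads
# of the same path are served from the cache instead of hitting the
# rate-limited workspace API.
path = cache.get_path("/some/notebook.py")
first = path.read_text()
second = cache.get_path("/some/notebook.py").read_text()
assert first == second
```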

File tree: 6 files changed, +261 −5 lines changed
src/databricks/labs/ucx/mixins/cached_workspace_path.py (new file, +139 −0)

```python
from __future__ import annotations

import os
from collections import OrderedDict
from collections.abc import Generator
from io import StringIO, BytesIO

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ObjectInfo
from databricks.labs.blueprint.paths import WorkspacePath


class _CachedIO:

    def __init__(self, content):
        self._content = content
        self._index = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False

    def read(self, *args, **_kwargs):
        count = -1 if len(args) < 1 or args[0] < 1 else args[0]
        if count == -1:
            return self._content
        start = self._index
        end = self._index + count
        if start >= len(self._content):
            return None
        self._index = self._index + count
        return self._content[start:end]

    def __iter__(self):
        if isinstance(self._content, str):
            yield from StringIO(self._content)
            return
        yield from self._as_string_io().__iter__()

    def with_mode(self, mode: str):
        if 'b' in mode:
            return self._as_bytes_io()
        return self._as_string_io()

    def _as_bytes_io(self):
        if isinstance(self._content, bytes):
            return self
        return BytesIO(self._content.encode("utf-8-sig"))

    def _as_string_io(self):
        if isinstance(self._content, str):
            return self
        return StringIO(self._content.decode("utf-8"))


# lru_cache won't let us invalidate cache entries
# so we provide our own custom lru_cache
class _PathLruCache:

    def __init__(self, max_entries: int):
        self._datas: OrderedDict[str, bytes | str] = OrderedDict()
        self._max_entries = max_entries

    def open(self, cached_path: _CachedPath, mode, buffering, encoding, errors, newline):
        path = str(cached_path)
        if path in self._datas:
            self._datas.move_to_end(path)
            return _CachedIO(self._datas[path]).with_mode(mode)
        io_obj = WorkspacePath.open(cached_path, mode, buffering, encoding, errors, newline)
        # can't read twice from an IO so need to cache data rather than the io object
        data = io_obj.read()
        self._datas[path] = data
        result = _CachedIO(data).with_mode(mode)
        if len(self._datas) > self._max_entries:
            self._datas.popitem(last=False)
        return result

    def clear(self):
        self._datas.clear()

    def remove(self, path: str):
        if path in self._datas:
            self._datas.pop(path)


class _CachedPath(WorkspacePath):
    def __init__(self, cache: _PathLruCache, ws: WorkspaceClient, *args: str | bytes | os.PathLike):
        super().__init__(ws, *args)
        self._cache = cache

    def with_object_info(self, object_info: ObjectInfo):
        self._cached_object_info = object_info
        return self

    def with_segments(self, *path_segments: bytes | str | os.PathLike) -> _CachedPath:
        return type(self)(self._cache, self._ws, *path_segments)

    def iterdir(self) -> Generator[_CachedPath, None, None]:
        for object_info in self._ws.workspace.list(self.as_posix()):
            path = object_info.path
            if path is None:
                msg = f"Cannot initialise without object path: {object_info}"
                raise ValueError(msg)
            child = _CachedPath(self._cache, self._ws, path)
            yield child.with_object_info(object_info)

    def open(
        self,
        mode: str = "r",
        buffering: int = -1,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ):
        # only cache reads
        if 'r' in mode:
            return self._cache.open(self, mode, buffering, encoding, errors, newline)
        self._cache.remove(str(self))
        return super().open(mode, buffering, encoding, errors, newline)

    def _cached_open(self, mode: str, buffering: int, encoding: str | None, errors: str | None, newline: str | None):
        return super().open(mode, buffering, encoding, errors, newline)

    # _rename calls unlink so no need to override it
    def unlink(self, missing_ok: bool = False) -> None:
        self._cache.remove(str(self))
        return super().unlink(missing_ok)


class WorkspaceCache:

    def __init__(self, ws: WorkspaceClient, max_entries=2048):
        self._ws = ws
        self._cache = _PathLruCache(max_entries)

    def get_path(self, path: str):
        return _CachedPath(self._cache, self._ws, path)
```
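To make the `max_entries` bound concrete, here is a hedged sketch of the LRU eviction behavior, mocked in the same style as the unit tests further down (it assumes `workspace.download` is the only workspace call needed to read a file):

```python
import io
from unittest.mock import create_autospec

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache

ws = create_autospec(WorkspaceClient)
ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO(b"abc")

cache = WorkspaceCache(ws, max_entries=1)
cache.get_path("a").read_text()  # downloads "a" and caches it
cache.get_path("b").read_text()  # downloads "b"; "a" is evicted (least recently used)
cache.get_path("a").read_text()  # cache miss, so "a" is downloaded again
assert ws.workspace.download.call_count == 3
```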

src/databricks/labs/ucx/source_code/jobs.py (+6 −4)

```diff
@@ -11,14 +11,15 @@
 from urllib import parse
 
 from databricks.labs.blueprint.parallel import ManyError, Threads
-from databricks.labs.blueprint.paths import DBFSPath, WorkspacePath
+from databricks.labs.blueprint.paths import DBFSPath
 from databricks.labs.lsql.backends import SqlBackend
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound
 from databricks.sdk.service import compute, jobs
 
 from databricks.labs.ucx.assessment.crawlers import runtime_version_tuple
 from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
+from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache
 from databricks.labs.ucx.source_code.base import CurrentSessionState, is_a_notebook, LocatedAdvice
 from databricks.labs.ucx.source_code.graph import (
     Dependency,
@@ -72,6 +73,7 @@ def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job):
         self._task = task
         self._job = job
         self._ws = ws
+        self._cache = WorkspaceCache(ws)
         self._named_parameters: dict[str, str] | None = {}
         self._parameters: list[str] | None = []
         self._spark_conf: dict[str, str] | None = {}
@@ -123,7 +125,7 @@ def _as_path(self, path: str) -> Path:
         parsed_path = parse.urlparse(path)
         match parsed_path.scheme:
             case "":
-                return WorkspacePath(self._ws, path)
+                return self._cache.get_path(path)
             case "dbfs":
                 return DBFSPath(self._ws, parsed_path.path)
             case other:
@@ -186,7 +188,7 @@ def _register_notebook(self, graph: DependencyGraph) -> Iterable[DependencyProbl
         notebook_path = self._task.notebook_task.notebook_path
         logger.info(f'Discovering {self._task.task_key} entrypoint: {notebook_path}')
         # Notebooks can't be on DBFS.
-        path = WorkspacePath(self._ws, notebook_path)
+        path = self._cache.get_path(notebook_path)
         return graph.register_notebook(path, False)
 
     def _register_spark_python_task(self, graph: DependencyGraph):
@@ -261,7 +263,7 @@ def _register_pipeline_task(self, graph: DependencyGraph):
         if library.notebook.path:
             notebook_path = library.notebook.path
             # Notebooks can't be on DBFS.
-            path = WorkspacePath(self._ws, notebook_path)
+            path = self._cache.get_path(notebook_path)
             # the notebook is the root of the graph, so there's no context to inherit
             yield from graph.register_notebook(path, inherit_context=False)
         if library.jar:
```
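A note on cache scope: each instance built by this `__init__` owns its own `WorkspaceCache`, so repeated `_as_path` calls and notebook registrations on the same instance share one cache; nothing is shared across instances.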

tests/unit/__init__.py (+3 −1)

The mocked `workspace.download` must now accept the keyword-only `format` argument that the workspace path implementation passes when downloading content.

```diff
@@ -201,7 +201,9 @@ def mock_workspace_client(
             ]
         ),
     }
-    ws.workspace.download.side_effect = lambda file_name: io.StringIO(download_yaml[os.path.basename(file_name)])
+    ws.workspace.download.side_effect = lambda file_name, *, format=None: io.StringIO(
+        download_yaml[os.path.basename(file_name)]
+    )
     return ws
```
tests/unit/conftest.py (+1 −0)

```diff
@@ -7,6 +7,7 @@
 import pytest
 from databricks.labs.blueprint.installation import MockInstallation
 from databricks.labs.lsql.backends import MockBackend
+
 from databricks.labs.ucx.source_code.graph import BaseNotebookResolver
 from databricks.labs.ucx.source_code.path_lookup import PathLookup
 from databricks.sdk import WorkspaceClient, AccountClient
```

tests/unit/mixins/__init__.py

Whitespace-only changes.
tests/unit/mixins/test_cached_workspace_path.py (new file, +112 −0)

```python
import io
from unittest.mock import create_autospec

import pytest

from tests.unit import mock_workspace_client

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ObjectInfo, ObjectType

from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache
from databricks.labs.ucx.source_code.base import guess_encoding


class TestWorkspaceCache(WorkspaceCache):

    @property
    def data_cache(self):
        return self._cache


def test_path_like_returns_cached_instance():
    cache = TestWorkspaceCache(mock_workspace_client())
    parent = cache.get_path("path")
    child = parent / "child"
    _cache = getattr(child, "_cache")
    assert _cache == cache.data_cache


def test_iterdir_returns_cached_instances():
    ws = create_autospec(WorkspaceClient)
    ws.workspace.get_status.return_value = ObjectInfo(object_type=ObjectType.DIRECTORY)
    ws.workspace.list.return_value = list(ObjectInfo(object_type=ObjectType.FILE, path=s) for s in ("a", "b", "c"))
    cache = TestWorkspaceCache(ws)
    parent = cache.get_path("dir")
    assert parent.is_dir()
    for child in parent.iterdir():
        _cache = getattr(child, "_cache")
        assert _cache == cache.data_cache


def test_download_is_only_called_once_per_instance():
    ws = mock_workspace_client()
    ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
    cache = WorkspaceCache(ws)
    path = cache.get_path("path")
    for _ in range(0, 4):
        _ = path.read_text()
    assert ws.workspace.download.call_count == 1


def test_download_is_only_called_once_across_instances():
    ws = mock_workspace_client()
    ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
    cache = WorkspaceCache(ws)
    for _ in range(0, 4):
        path = cache.get_path("path")
        _ = path.read_text()
    assert ws.workspace.download.call_count == 1


def test_download_is_called_again_after_unlink():
    ws = mock_workspace_client()
    ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
    cache = WorkspaceCache(ws)
    path = cache.get_path("path")
    _ = path.read_text()
    path = cache.get_path("path")
    path.unlink()
    _ = path.read_text()
    assert ws.workspace.download.call_count == 2


def test_download_is_called_again_after_rename():
    ws = mock_workspace_client()
    ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
    cache = WorkspaceCache(ws)
    path = cache.get_path("path")
    _ = path.read_text()
    path.rename("abcd")
    _ = path.read_text()
    assert ws.workspace.download.call_count == 3  # rename reads the old content


def test_encoding_is_guessed_after_download():
    ws = mock_workspace_client()
    ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
    cache = WorkspaceCache(ws)
    path = cache.get_path("path")
    _ = path.read_text()
    guess_encoding(path)


@pytest.mark.parametrize(
    "mode, data",
    [
        ("r", io.BytesIO("abc".encode("utf-8-sig"))),
        ("rb", io.BytesIO("abc".encode("utf-8-sig"))),
    ],
)
def test_sequential_read_completes(mode, data):
    ws = mock_workspace_client()
    ws.workspace.download.side_effect = lambda _, *, format: data
    cache = WorkspaceCache(ws)
    path = cache.get_path("path")
    with path.open(mode) as file:
        count = 0
        while _ := file.read(1):
            count = count + 1
            if count > 10:
                break
    assert count < 10
```
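The tests above exercise invalidation via `unlink` and `rename`; plain writes invalidate as well, since `_CachedPath.open` only consults the cache for read modes. A hedged sketch of the resulting read-after-write behavior (not part of the PR, and it assumes the underlying `WorkspacePath` write path works against the mocked client):

```python
import io

from tests.unit import mock_workspace_client

from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache

ws = mock_workspace_client()
ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO(b"abc")

cache = WorkspaceCache(ws)
path = cache.get_path("path")
_ = path.read_text()       # first download, content cached
with path.open("w") as f:  # write mode: the cache entry is dropped
    f.write("new content")
_ = path.read_text()       # cache miss, second download
assert ws.workspace.download.call_count == 2
```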
