Skip to content

Commit db669cb

Browse files
authored
Merge pull request #38 from PySport/feat/add-load-files-cache
Add load files cache
2 parents da565e4 + c3ad621 commit db669cb

File tree

2 files changed

+145
-4
lines changed

2 files changed

+145
-4
lines changed

ingestify/application/dataset_store.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import mimetypes
55
import os
66
import shutil
7+
from contextlib import contextmanager
8+
import threading
79
from dataclasses import asdict
810
from io import BytesIO
911

@@ -47,6 +49,8 @@ def __init__(
4749
self.storage_compression_method = "gzip"
4850
self.bucket = bucket
4951
self.event_bus: Optional[EventBus] = None
52+
# Create thread-local storage for caching
53+
self._thread_local = threading.local()
5054

5155
# def __getstate__(self):
5256
# return {"file_repository": self.file_repository, "bucket": self.bucket}
@@ -58,6 +62,34 @@ def dispatch(self, event):
5862
if self.event_bus:
5963
self.event_bus.dispatch(event)
6064

65+
@contextmanager
def with_file_cache(self):
    """Context manager to enable file caching during its scope.

    Files loaded within this context will be cached and reused,
    avoiding multiple downloads of the same file. The cache lives in
    thread-local storage, so concurrent threads sharing this store do
    not see each other's cached files.

    Example:
        # Without caching (loads files twice)
        analyzer1 = StatsAnalyzer(store, dataset)
        analyzer2 = VisualizationTool(store, dataset)

        # With caching (files are loaded once and shared)
        with store.with_file_cache():
            analyzer1 = StatsAnalyzer(store, dataset)
            analyzer2 = VisualizationTool(store, dataset)
    """
    # Remember whether caching was already active so nested
    # with_file_cache() blocks do not disable caching (or drop the
    # cache) for an enclosing block when the inner block exits.
    was_enabled = getattr(self._thread_local, "use_file_cache", False)

    if not was_enabled:
        # Outermost block: start with a fresh, empty cache.
        self._thread_local.file_cache = {}
    self._thread_local.use_file_cache = True

    try:
        yield
    finally:
        # Restore the prior state: nested blocks leave caching on for
        # the enclosing scope; the outermost block disables caching
        # and releases all cached files.
        self._thread_local.use_file_cache = was_enabled
        if not was_enabled:
            self._thread_local.file_cache = {}
6193
def save_ingestion_job_summary(self, ingestion_job_summary):
6294
self.dataset_repository.save_ingestion_job_summary(ingestion_job_summary)
6395

@@ -384,10 +416,21 @@ def get_stream(file_):
384416
self.file_repository.load_content(storage_path=file_.storage_path)
385417
)
386418

387-
loaded_file = LoadedFile(
388-
stream_=get_stream if lazy else get_stream(file),
389-
**file.model_dump(),
390-
)
419+
def make_loaded_file():
420+
return LoadedFile(
421+
stream_=get_stream if lazy else get_stream(file),
422+
**file.model_dump(),
423+
)
424+
425+
# Using getattr with a default value of False - simple one-liner
426+
if getattr(self._thread_local, "use_file_cache", False):
427+
key = (dataset.dataset_id, current_revision.revision_id, file.file_id)
428+
if key not in self._thread_local.file_cache:
429+
self._thread_local.file_cache[key] = make_loaded_file()
430+
loaded_file = self._thread_local.file_cache[key]
431+
else:
432+
loaded_file = make_loaded_file()
433+
391434
files[file.file_id] = loaded_file
392435
return FileCollection(files, auto_rewind=auto_rewind)
393436

ingestify/tests/test_file_cache.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import pytest
2+
from io import BytesIO
3+
from unittest.mock import patch
4+
from datetime import datetime, timezone
5+
6+
from ingestify.main import get_engine
7+
from ingestify.domain import Dataset, Identifier, Revision, File
8+
from ingestify.domain.models.dataset.revision import RevisionSource, SourceType
9+
10+
11+
def test_file_cache(config_file):
    """Verify that with_file_cache() deduplicates file loads within its scope."""
    # Build the store from the engine configured by the test fixture.
    engine = get_engine(config_file, "main")
    store = engine.store

    # One timezone-aware timestamp reused for every datetime field below.
    stamp = datetime.now(timezone.utc)

    # Minimal file record pointing at a fake storage location.
    stored_file = File(
        file_id="test_file_id",
        data_feed_key="test_file",
        tag="test_tag",
        data_serialization_format="txt",
        storage_path="test/path",
        storage_size=100,
        storage_compression_method="none",
        created_at=stamp,
        modified_at=stamp,
        size=100,
        content_type="text/plain",
        data_spec_version="v1",
    )

    # A single revision carrying that file.
    only_revision = Revision(
        revision_id=1,
        created_at=stamp,
        description="Test revision",
        modified_files=[stored_file],
        source={"source_type": SourceType.MANUAL, "source_id": "test"},
    )

    # Dataset wrapping the revision above.
    dataset = Dataset(
        bucket="test-bucket",
        dataset_id="test-dataset",
        name="Test Dataset",
        state="COMPLETE",
        identifier=Identifier(test_id=1),
        dataset_type="test",
        provider="test-provider",
        metadata={},
        created_at=stamp,
        updated_at=stamp,
        last_modified_at=stamp,
        revisions=[only_revision],
    )

    # Identity "decompressor" standing in for the gzip reader.
    def passthrough(stream):
        return stream

    # Stub out both the repository download and the stream preparation so
    # load_files() never touches real storage.
    with patch.object(
        store.file_repository, "load_content"
    ) as load_content_mock, patch.object(
        store, "_prepare_read_stream"
    ) as prepare_stream_mock:
        load_content_mock.return_value = BytesIO(b"test content")
        prepare_stream_mock.return_value = (passthrough, "")

        # No cache active: every load_files() call hits the repository.
        store.load_files(dataset)
        store.load_files(dataset)
        assert load_content_mock.call_count == 2

        load_content_mock.reset_mock()

        # Cache active: repeated loads inside the context download once.
        with store.with_file_cache():
            store.load_files(dataset)
            store.load_files(dataset)
        assert load_content_mock.call_count == 1

        # Leaving the context turns caching back off.
        store.load_files(dataset)
        assert load_content_mock.call_count == 2

0 commit comments

Comments
 (0)