Skip to content

Commit c9346b4

Browse files
committed
Cache dataset in-memory
1 parent d0e340a commit c9346b4

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

datastream/dataset.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,27 @@ def zip(datasets: List[Dataset]) -> Dataset[Tuple]:
487487
))
488488
)
489489

490+
def cache(self, key_column):
    '''
    Cache dataset in-memory based on key column.

    Each distinct value in ``key_column`` is loaded through the original
    ``get_item`` at most once; later requests for the same key return the
    memoized item. The cache is unbounded and lives as long as the
    returned dataset's ``get_item`` does.

    If several rows share a key, they all resolve to the item of the
    last row with that key in ``self.dataframe``.

    Args:
        key_column: Name of the dataframe column whose values identify
            items. Values must be hashable.

    Returns:
        A new ``Dataset`` with the same dataframe and length whose
        ``get_item`` is served from the cache.
    '''
    from functools import lru_cache

    # Key value -> row position in the original dataframe. On duplicate
    # keys the later position overwrites the earlier one.
    key_mapping = dict(zip(
        self.dataframe[key_column],
        range(len(self)),
    ))

    @lru_cache(maxsize=None)
    def only_key(key):
        # Always load from the original dataframe so the cached item does
        # not depend on which dataframe the lookup came through.
        return self.get_item(self.dataframe, key_mapping[key])

    def get_cached_item(dataframe, index):
        # Select the column before the row. ``dataframe.iloc[index]``
        # would build a Series spanning every column, coercing the key to
        # a common dtype (e.g. int64 -> float64 next to float columns,
        # losing precision for large ids) and doing needless work.
        return only_key(dataframe[key_column].iloc[index])

    return Dataset(
        dataframe=self.dataframe,
        length=self.length,
        get_item=get_cached_item,
    )
510+
490511

491512
def test_equal():
492513
dataset1 = Dataset.from_subscriptable([4, 7, 12])

datastream/datastream.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ def multi_sample(self: Datastream[T], n: int) -> Datastream[T]:
253253
MultiSampler.from_number(n, self.dataset),
254254
)
255255

256+
def cache(self, key_column):
    '''
    Return a datastream whose dataset is cached in-memory.

    The caching itself is delegated to :func:`Dataset.cache`, keyed on
    ``key_column``; the sampler is passed on as-is.
    '''
    cached_dataset = self.dataset.cache(key_column)
    return Datastream(cached_dataset, self.sampler)
262+
256263

257264
def test_datastream_merge():
258265

0 commit comments

Comments
 (0)