
Commit 2ee62e2

Authored by CodyCBakerPhD (Cody Baker), with pre-commit-ci[bot] and bendichter
Add caching for repeated table data calls (#230)
* added utility for easy data caching
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* debugged
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* more general debug
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* swap to using selections and h5py Datasets
* Update nwbinspector/utils.py (Co-authored-by: Ben Dichter <[email protected]>)
* support tuple caching
* support array slicing for cached data selection
* fixed hashing issues; seeing if table message is still changed
* fix table test
* fixed tests?

Co-authored-by: Cody Baker <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ben Dichter <[email protected]>
1 parent cef1bdb commit 2ee62e2

4 files changed: +55 −17 lines

nwbinspector/checks/tables.py

Lines changed: 9 additions & 4 deletions
@@ -9,6 +9,7 @@
 from ..register_checks import register_check, InspectorMessage, Importance
 from ..utils import (
+    _cache_data_selection,
     format_byte_size,
     is_ascending_series,
     is_dict_in_string,
@@ -55,7 +56,7 @@ def check_time_interval_time_columns(time_intervals: TimeIntervals, nelems: int
     unsorted_cols = []
     for column in time_intervals.columns:
         if column.name[-5:] == "_time":
-            if not is_ascending_series(column, nelems):
+            if not is_ascending_series(column.data, nelems):
                 unsorted_cols.append(column.name)
     if unsorted_cols:
         return InspectorMessage(
@@ -79,7 +80,11 @@ def check_time_intervals_stop_after_start(time_intervals: TimeIntervals, nelems:
     very long so you don't need to load the entire array into memory. Use None to
     load the entire arrays.
     """
-    if np.any(np.asarray(time_intervals["stop_time"][:nelems]) - np.asarray(time_intervals["start_time"][:nelems]) < 0):
+    if np.any(
+        np.asarray(_cache_data_selection(data=time_intervals["stop_time"].data, selection=slice(nelems)))
+        - np.asarray(_cache_data_selection(data=time_intervals["start_time"].data, selection=slice(nelems)))
+        < 0
+    ):
         return InspectorMessage(
             message=(
                 "stop_times should be greater than start_times. Make sure the stop times are with respect to the "
@@ -106,7 +111,7 @@ def check_column_binary_capability(table: DynamicTable, nelems: int = 200):
         if np.asarray(column.data[0]).itemsize == 1:
             continue  # already boolean, int8, or uint8
         try:
-            unique_values = np.unique(column.data[:nelems])
+            unique_values = np.unique(_cache_data_selection(data=column.data, selection=slice(nelems)))
         except TypeError:  # some contained objects are unhashable or have no comparison defined
             continue
         if unique_values.size != 2:
@@ -174,7 +179,7 @@ def check_table_values_for_dict(table: DynamicTable, nelems: int = 200):
     for column in table.columns:
         if not hasattr(column, "data") or isinstance(column, VectorIndex) or not isinstance(column.data[0], str):
             continue
-        for string in column.data[:nelems]:
+        for string in _cache_data_selection(data=column.data, selection=slice(nelems)):
             if is_dict_in_string(string=string):
                 message = (
                     f"The column '{column.name}' contains a string value that contains a dictionary! Please "

nwbinspector/utils.py

Lines changed: 37 additions & 4 deletions
@@ -2,18 +2,48 @@
 import os
 import re
 import json
-import numpy as np
-from typing import TypeVar, Optional, List, Dict, Callable
+from typing import TypeVar, Union, Optional, List, Dict, Callable, Tuple
 from pathlib import Path
 from importlib import import_module
 from packaging import version
 from time import sleep
+from functools import lru_cache
+
+import h5py
+import numpy as np
+from numpy.typing import ArrayLike
+

 PathType = TypeVar("PathType", str, Path)  # For types that can be either files or folders
 FilePathType = TypeVar("FilePathType", str, Path)
 OptionalListOfStrings = Optional[List[str]]

 dict_regex = r"({.+:.+})"
+MAX_CACHE_ITEMS = 1000  # lru_cache default is 128 calls of matching input/output, but might need more to get use here
+
+
+@lru_cache(maxsize=MAX_CACHE_ITEMS)
+def _cache_data_retrieval_command(
+    data: h5py.Dataset, reduced_selection: Tuple[Tuple[Optional[int], Optional[int], Optional[int]]]
+) -> np.ndarray:
+    """LRU caching for _cache_data_selection cannot be applied to list inputs; this expects the tuple or Dataset."""
+    selection = tuple([slice(*reduced_slice) for reduced_slice in reduced_selection])  # reconstitute the slices
+    return data[selection]
+
+
+def _cache_data_selection(data: Union[h5py.Dataset, ArrayLike], selection: Union[slice, Tuple[slice]]) -> np.ndarray:
+    """Extract the selection lazily from the data object for efficient caching (most beneficial during streaming)."""
+    if isinstance(data, np.memmap):  # Technically np.memmap should be able to support this type of behavior as well
+        return data[selection]  # But they aren't natively hashable either...
+    if not isinstance(data, h5py.Dataset):  # No need to attempt to cache if already an in-memory object
+        return np.array(data)[selection]
+
+    # slices also aren't hashable, but their reduced representation is
+    if isinstance(selection, slice):  # If a single slice
+        reduced_selection = tuple([selection.__reduce__()[1]])
+    else:
+        reduced_selection = tuple([selection_slice.__reduce__()[1] for selection_slice in selection])
+    return _cache_data_retrieval_command(data=data, reduced_selection=reduced_selection)
@@ -52,9 +82,12 @@ def check_regular_series(series: np.ndarray, tolerance_decimals: int = 9):
     return len(uniq_diff_ts) == 1


-def is_ascending_series(series: np.ndarray, nelems=None):
+def is_ascending_series(series: Union[h5py.Dataset, ArrayLike], nelems=None):
     """General purpose function for determining if a series is monotonic increasing."""
-    return np.all(np.diff(series[:nelems]) > 0)
+    if isinstance(series, h5py.Dataset):
+        return np.all(np.diff(_cache_data_selection(data=series, selection=slice(nelems))) > 0)
+    else:
+        return np.all(np.diff(series[:nelems]) > 0)  # already in memory, no need to cache


 def is_dict_in_string(string: str):
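
A note on the hashing trick above: plain slices are not hashable (at least prior to Python 3.12), so they cannot serve directly as lru_cache keys, but slice.__reduce__() exposes a hashable (start, stop, step) tuple that round-trips back into an identical slice. Below is a minimal sketch of that mechanism plus a hypothetical usage example; the file name "example.nwb" and the dataset path are stand-ins for illustration, not taken from this commit:

import h5py
import numpy as np
from nwbinspector.utils import _cache_data_selection

# slice.__reduce__() yields (slice, (start, stop, step)); the args tuple is hashable
reduced = slice(200).__reduce__()[1]
assert reduced == (None, 200, None)
assert slice(*reduced) == slice(200)  # reconstitutes the original slice

# Hypothetical usage: repeated identical selections are served from the LRU cache
with h5py.File("example.nwb", "r") as file:  # "example.nwb" is a stand-in path
    dset = file["intervals/trials/start_time"]  # hypothetical dataset path
    first = _cache_data_selection(data=dset, selection=slice(200))   # reads from disk
    second = _cache_data_selection(data=dset, selection=slice(200))  # cache hit
    assert np.array_equal(first, second)

The payoff is largest when the same table column is sliced by several checks in a row, especially when the file is streamed rather than read from local disk.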

tests/unit_tests/test_tables.py

Lines changed: 2 additions & 2 deletions
@@ -183,8 +183,8 @@ def test_binary_int_fail(self):
         assert check_column_binary_capability(table=self.table) == [
             InspectorMessage(
                 message=(
-                    "Column 'test_col' uses 'integers' but has binary values [0 1]. Consider making it boolean instead "
-                    f"and renaming the column to start with 'is_'; doing so will save {platform_saved_bytes}."
+                    "Column 'test_col' uses 'integers' but has binary values [0 1]. Consider making it boolean "
+                    f"instead and renaming the column to start with 'is_'; doing so will save {platform_saved_bytes}."
                 ),
                 importance=Importance.BEST_PRACTICE_SUGGESTION,
                 check_function_name="check_column_binary_capability",

tests/unit_tests/test_time_series.py

Lines changed: 7 additions & 7 deletions
@@ -133,7 +133,12 @@ def test_check_timestamps_empty_timestamps():
     )


-def test_check_timestamps_ascending():
+def test_pass_check_timestamps_ascending_pass():
+    time_series = pynwb.TimeSeries(name="test_time_series", unit="test_units", data=[1, 2, 3], timestamps=[1, 2, 3])
+    assert check_timestamps_ascending(time_series) is None
+
+
+def test_check_timestamps_ascending_fail():
     time_series = pynwb.TimeSeries(name="test_time_series", unit="test_units", data=[1, 2, 3], timestamps=[1, 3, 2])
     assert check_timestamps_ascending(time_series) == InspectorMessage(
         message="test_time_series timestamps are not ascending.",
@@ -145,11 +150,6 @@ def test_check_timestamps_ascending():
     )


-def test_pass_check_timestamps_ascending():
-    time_series = pynwb.TimeSeries(name="test_time_series", unit="test_units", data=[1, 2, 3], timestamps=[1, 2, 3])
-    assert check_timestamps_ascending(time_series) is None
-
-
 def test_check_missing_unit_pass():
     time_series = pynwb.TimeSeries(name="test_time_series", unit="test_units", data=[1, 2, 3], timestamps=[1, 2, 3])
     assert check_missing_unit(time_series) is None
@@ -169,7 +169,7 @@ def test_check_missing_unit_fail():

 def test_check_positive_resolution_pass():
     time_series = pynwb.TimeSeries(name="test", unit="test_units", data=[1, 2, 3], timestamps=[1, 2, 3], resolution=3.4)
-    assert check_timestamps_ascending(time_series) is None
+    assert check_resolution(time_series) is None
