Skip to content

Commit b366dd0

Browse files
committed
improve testing cache
1 parent 965a13e commit b366dd0

File tree

15 files changed

+1123
-183
lines changed

15 files changed

+1123
-183
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
2+
"""Cache for test evaluations"""
3+
4+
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
5+
import hashlib
6+
import data_algebra
7+
import data_algebra.db_model
8+
9+
10+
def hash_data_frame(d) -> str:
    """
    Get a hash code representing a data frame.

    :param d: data frame (must satisfy the default data model's instance check)
    :return: hash code as a string
    """
    # Bug fix: the check's boolean result was previously discarded, so an
    # inappropriate instance slipped through silently.  Assert it, consistent
    # with the identical check in make_cache_key.
    assert data_algebra.default_data_model.is_appropriate_data_instance(d)
    hash_str = hashlib.sha256(
        data_algebra.default_data_model.pd.util.hash_pandas_object(d).values
    ).hexdigest()
    # Prefix with shape and column names so frames whose row-hashes collide
    # but whose layout differs still get distinct keys.
    return f'{d.shape}_{list(d.columns)}_{hash_str}'
22+
23+
24+
class EvalKey(NamedTuple):
    """Carry description of data transform key"""
    # database model identifier, built as str(db_model) in make_cache_key
    db_model_name: str
    # the SQL statement text being evaluated
    sql: str
    # (data name, data frame hash) pairs, sorted by name in make_cache_key;
    # a tuple of tuples so the whole key stays hashable
    dat_map_list: Tuple[Tuple[str, str], ...]
29+
30+
31+
def make_cache_key(
        *,
        db_model: data_algebra.db_model.DBModel,
        sql: str,
        data_map: Dict[str, Any],
):
    """
    Create an immutable, hashable key describing one evaluation.

    :param db_model: database model the SQL targets
    :param sql: SQL statement text
    :param data_map: map of names to data frames used by the SQL
    :return: EvalKey for this (model, sql, data) combination
    """
    assert isinstance(db_model, data_algebra.db_model.DBModel)
    assert isinstance(sql, str)
    assert isinstance(data_map, dict)
    # sort names so the key is independent of dict insertion order
    sorted_names = sorted(data_map.keys())
    for name in sorted_names:
        assert isinstance(name, str)
        assert data_algebra.default_data_model.is_appropriate_data_instance(data_map[name])
    hashed_pairs = tuple((name, hash_data_frame(data_map[name])) for name in sorted_names)
    return EvalKey(db_model_name=str(db_model), sql=sql, dat_map_list=hashed_pairs)
53+
54+
55+
class ResultCache:
    """Cache for test results. Maps keys to data frames."""
    # True once any store() call has changed the cache contents
    dirty: bool
    # optional side-store of input/result frames by hash, kept for debugging
    data_cache: Optional[Dict[str, Any]]
    # primary cache: EvalKey -> result data frame
    result_cache: Dict[EvalKey, Any]

    def __init__(self):
        self.dirty = False
        self.data_cache = dict()
        self.result_cache = dict()

    def get(self,
            *,
            db_model: data_algebra.db_model.DBModel,
            sql: str,
            data_map: Dict[str, Any]):
        """get result from cache, raise KeyError if not present"""
        k = make_cache_key(
            db_model=db_model,
            sql=sql,
            data_map=data_map)
        res = self.result_cache[k]
        assert data_algebra.default_data_model.is_appropriate_data_instance(res)
        # copy so callers can't mutate the cached frame in place
        return res.copy()

    def store(self,
              *,
              db_model: data_algebra.db_model.DBModel,
              sql: str,
              data_map: Dict[str, Any],
              res) -> None:
        """Store result to cache, mark dirty if change."""
        assert data_algebra.default_data_model.is_appropriate_data_instance(res)
        op_key = make_cache_key(
            db_model=db_model,
            sql=sql,
            data_map=data_map)
        try:
            previous = self.result_cache[op_key]
            if previous.equals(res):
                # identical result already cached: nothing to do, stay clean
                return
        except KeyError:
            pass
        self.dirty = True
        self.result_cache[op_key] = res.copy()
        # values saved for debugging
        # Bug fix: the guard tested `self.dirty is not None`, which is always
        # true (dirty is a bool).  The Optional annotation on data_cache shows
        # the intended guard is on data_cache itself.
        if self.data_cache is not None:
            for d in list(data_map.values()) + [res]:
                d_key = hash_data_frame(d)
                # assuming no spurious key collisions
                if d_key not in self.data_cache:
                    self.data_cache[d_key] = d.copy()

build/lib/data_algebra/test_util.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
"""
44

55
import pickle
6-
import hashlib
76
import traceback
87
from typing import Any, Optional
98

109
import numpy
1110

1211
import data_algebra
12+
import data_algebra.eval_cache
1313
import data_algebra.db_model
1414
import data_algebra.SQLite
1515
import data_algebra.BigQuery
@@ -29,12 +29,12 @@
2929
run_direct_ops_path_tests = False
3030

3131
# global test result cache
32-
global_test_result_cache = None
32+
global_test_result_cache: Optional[data_algebra.eval_cache.ResultCache] = None
3333

3434

3535
def re_parse(ops):
3636
"""
37-
Return copy of object made by dumpint to string via repr() and then evaluating that string.
37+
Return copy of object made by dumping to string via repr() and then evaluating that string.
3838
"""
3939
str1 = repr(ops)
4040
ops2 = eval(
@@ -167,30 +167,18 @@ def equivalent_frames(
167167
return True
168168

169169

170-
def hash_data_frame(d) -> str:
171-
"""
172-
Get a hash code representing a data frame.
173-
174-
:param d: data frame
175-
:return: hash code as a string
176-
"""
177-
return hashlib.sha256(
178-
data_algebra.default_data_model.pd.util.hash_pandas_object(d).values
179-
).hexdigest()
180-
181-
182170
def _run_handle_experiments(
183171
*,
184172
db_handle,
185-
data: Dict,
173+
data: Dict[str, Any],
186174
ops: ViewRepresentation,
187175
sql_statements: Iterable[str],
188176
expect,
189177
float_tol: float = 1e-8,
190178
check_column_order: bool = False,
191179
cols_case_sensitive: bool = False,
192180
check_row_order: bool = False,
193-
test_result_cache: Optional[dict] = None,
181+
test_result_cache: Optional[data_algebra.eval_cache.ResultCache] = None,
194182
alter_cache: bool = True,
195183
test_direct_ops_path=False,
196184
):
@@ -203,26 +191,19 @@ def _run_handle_experiments(
203191
assert db_handle.conn is not None
204192
if isinstance(db_handle.db_model, data_algebra.SQLite.SQLiteModel):
205193
test_direct_ops_path = True
206-
db_handle_key = str(db_handle.db_model)
207194
sql_statements = list(sql_statements)
208195
res_db_sql = list([None] * len(sql_statements)) # extra list() wrapper for PyCharm's type checker
209196
res_db_ops = None
210197
need_to_run = True
211-
dict_keys = list(data.keys())
212-
dict_keys.sort()
213-
data_key = " ".join([k + ":" + hash_data_frame(data[k]) for k in dict_keys])
214-
215-
def mk_key(ii):
216-
"""
217-
Build sql statement key.
218-
"""
219-
return db_handle_key + " " + sql_statements[ii] + " " + data_key
220-
221198
# inspect result cache for any prior results
222199
if test_result_cache is not None:
223200
for i in range(len(sql_statements)):
224201
try:
225-
res_db_sql[i] = test_result_cache[mk_key(i)].copy()
202+
res_db_sql[i] = test_result_cache.get(
203+
db_model=db_handle.db_model,
204+
sql=sql_statements[i],
205+
data_map=data,
206+
)
226207
except KeyError:
227208
pass
228209
need_to_run = test_direct_ops_path or numpy.any(
@@ -248,7 +229,12 @@ def mk_key(ii):
248229
and (test_result_cache is not None)
249230
and (res_db_sql_i is not None)
250231
):
251-
test_result_cache[mk_key(i)] = res_db_sql_i.copy()
232+
test_result_cache.store(
233+
db_model=db_handle.db_model,
234+
sql=sql_statements[i],
235+
data_map=data,
236+
res=res_db_sql_i,
237+
)
252238
except AssertionError as ase:
253239
traceback.print_exc()
254240
caught = ase

0 commit comments

Comments
 (0)