Skip to content

Commit 796bd26

Browse files
authored
Merge pull request #4370 from markotoplak/from_table-cache-less
[FIX] Table.from_table: fix caching with reused ids
2 parents 098d97d + b175dfa commit 796bd26

File tree

2 files changed

+160
-7
lines changed

2 files changed

+160
-7
lines changed

Orange/data/table.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import threading
44
import warnings
5+
import weakref
56
import zlib
67
from collections import Iterable, Sequence, Sized
78
from functools import reduce
@@ -316,6 +317,12 @@ def from_table(cls, domain, source, row_indices=...):
316317
:rtype: Orange.data.Table
317318
"""
318319

320+
def valid_refs(weakrefs):
    """
    Return True if every weak reference in `weakrefs` is still alive.

    The conversion caches are keyed on ``id()`` values, which Python may
    hand out again once the original object has been garbage collected.
    A dead reference therefore means the cached entry was stored for a
    different, earlier object and must not be reused.

    :param weakrefs: iterable of zero-argument callables (``weakref.ref``
        objects) that return their referent, or None once it is gone
    :return: False if any reference is dead, True otherwise (also for an
        empty iterable)
    :rtype: bool
    """
    return all(r() is not None for r in weakrefs)
325+
319326
def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
320327
is_sparse=False, variables=[]):
321328
if not len(src_cols):
@@ -356,10 +363,13 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
356363
a[:, i] = variables[i].Unknown
357364
elif not isinstance(col, Integral):
358365
if isinstance(col, SharedComputeValue):
359-
if (id(col.compute_shared), id(source)) not in shared_cache:
360-
shared_cache[id(col.compute_shared), id(source)] = \
361-
col.compute_shared(source)
362-
shared = shared_cache[id(col.compute_shared), id(source)]
366+
shared, weakrefs = shared_cache.get((id(col.compute_shared), id(source)),
367+
(None, None))
368+
if shared is None or not valid_refs(weakrefs):
369+
shared, _ = shared_cache[(id(col.compute_shared), id(source))] = \
370+
col.compute_shared(source), \
371+
(weakref.ref(col.compute_shared), weakref.ref(source))
372+
363373
if row_indices is not ...:
364374
a[:, i] = match_density(
365375
col(source, shared_data=shared)[row_indices])
@@ -389,8 +399,9 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
389399
if new_cache:
390400
_thread_local.conversion_cache = {}
391401
else:
392-
cached = _thread_local.conversion_cache.get((id(domain), id(source)))
393-
if cached:
402+
cached, weakrefs = \
403+
_thread_local.conversion_cache.get((id(domain), id(source)), (None, None))
404+
if cached and valid_refs(weakrefs):
394405
return cached
395406
if domain is source.domain:
396407
table = cls.from_table_rows(source, row_indices)
@@ -443,7 +454,8 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
443454
else:
444455
cls._init_ids(self)
445456
self.attributes = getattr(source, 'attributes', {})
446-
_thread_local.conversion_cache[(id(domain), id(source))] = self
457+
_thread_local.conversion_cache[(id(domain), id(source))] = \
458+
self, (weakref.ref(domain), weakref.ref(source))
447459
return self
448460
finally:
449461
if new_cache:

Orange/tests/test_table.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from Orange import data
1818
from Orange.data import (filter, Unknown, Variable, Table, DiscreteVariable,
1919
ContinuousVariable, Domain, StringVariable)
20+
from Orange.data.util import SharedComputeValue
2021
from Orange.tests import test_dirname
2122
from Orange.data.table import _optimize_indices
2223

@@ -2693,5 +2694,145 @@ def run_from_table():
26932694
self.assertLess(duration, 0.5)
26942695

26952696

2697+
class PreprocessComputeValue:
    """
    Single-source compute value that reports every invocation.

    Calling an instance notifies `callback` with the data, transforms the
    data into `domain` and returns the first column of the data's own `X`.

    :param domain: domain the incoming data is transformed into on each call
    :param callback: callable invoked with the data on each call (the tests
        pass a ``Mock`` to count invocations)
    """

    def __init__(self, domain, callback):
        self.domain = domain
        self.callback = callback

    def __call__(self, data_):
        self.callback(data_)
        # The transformed table is deliberately discarded; the call is made
        # only for its side effects.
        # NOTE(review): presumably to trigger nested conversions - confirm.
        data_.transform(self.domain)
        return data_.X[:, 0]
2707+
2708+
2709+
class PreprocessShared:
    """
    Shared-computation stand-in that reports every invocation.

    Calling an instance notifies `callback` with the data, transforms the
    data into `domain` and returns True as the (dummy) shared result.

    :param domain: domain the incoming data is transformed into on each call
    :param callback: callable invoked with the data on each call
    """

    def __init__(self, domain, callback):
        self.domain = domain
        self.callback = callback

    def __call__(self, data_):
        self.callback(data_)
        # The result of the transform is not used; only the call matters.
        data_.transform(self.domain)
        return True
2719+
2720+
2721+
class PreprocessSharedComputeValue(SharedComputeValue):
    """
    Shared compute value that reports every `compute` invocation.

    :param callback: callable invoked with the data on each `compute` call
    :param shared: callable passed to the base class as `compute_shared`
    """

    def __init__(self, callback, shared):
        super().__init__(compute_shared=shared)
        self.callback = callback

    # pylint: disable=arguments-differ
    def compute(self, data_, shared_data):
        # `shared_data` is accepted to satisfy the interface but is not used.
        self.callback(data_)
        return data_.X[:, 0]
2731+
2732+
2733+
def preprocess_domain_single(domain, callback):
    """
    Preprocess domain with single-source compute values.

    Each attribute of `domain` gets a same-named continuous variable whose
    compute value transforms the data into a domain containing only that
    one attribute.

    :param domain: domain whose attributes are wrapped
    :param callback: callable handed to every compute value
    :return: a new domain of the derived continuous variables
    """
    return Domain([
        ContinuousVariable(
            name=at.name,
            compute_value=PreprocessComputeValue(Domain([at]), callback))
        for at in domain.attributes])
2740+
2741+
2742+
def preprocess_domain_shared(domain, callback, callback_shared):
    """
    Preprocess domain with shared compute values.

    A single `PreprocessShared` object is created for the whole `domain`
    and handed to the compute value of every derived variable.

    :param domain: domain whose attributes are wrapped
    :param callback: callable handed to every compute value
    :param callback_shared: callable handed to the shared computation
    :return: a new domain of the derived continuous variables
    """
    shared = PreprocessShared(domain, callback_shared)
    return Domain([
        ContinuousVariable(
            name=at.name,
            compute_value=PreprocessSharedComputeValue(callback, shared))
        for at in domain.attributes])
2750+
2751+
2752+
def preprocess_domain_single_stupid(domain, callback):
    """
    Preprocess domain with single-source compute values, implemented
    stupidly: before applying it, instead of transforming just one column
    into the input domain, do a needless transform of the whole domain.

    :param domain: domain whose attributes are wrapped; also the domain
        every compute value transforms the data into
    :param callback: callable handed to every compute value
    :return: a new domain of the derived continuous variables
    """
    return Domain([
        ContinuousVariable(
            name=at.name,
            compute_value=PreprocessComputeValue(domain, callback))
        for at in domain.attributes])
2761+
2762+
2763+
class EfficientTransformTests(unittest.TestCase):
    """
    Count how often compute values and shared computations are invoked
    when a table is transformed through chains of derived domains.
    """

    def setUp(self):
        self.iris = Table("iris")

    def test_simple(self):
        # One layer of single-source compute values.
        call_cv = Mock()
        d1 = preprocess_domain_single(self.iris.domain, call_cv)
        self.iris.transform(d1)
        self.assertEqual(4, call_cv.call_count)

    def test_shared(self):
        # One layer of shared compute values: the shared part runs once.
        call_cv = Mock()
        call_shared = Mock()
        d1 = preprocess_domain_shared(self.iris.domain, call_cv, call_shared)
        self.iris.transform(d1)
        self.assertEqual(4, call_cv.call_count)
        self.assertEqual(1, call_shared.call_count)

    def test_simple_simple_shared(self):
        # Two single-source layers followed by one shared layer.
        call_cv = Mock()
        d1 = preprocess_domain_single(self.iris.domain, call_cv)
        d2 = preprocess_domain_single(d1, call_cv)
        call_shared = Mock()
        d3 = preprocess_domain_shared(d2, call_cv, call_shared)
        self.iris.transform(d3)
        self.assertEqual(1, call_shared.call_count)
        self.assertEqual(12, call_cv.call_count)

    def test_simple_simple_shared_simple(self):
        # As above, with a further single-source layer on top.
        call_cv = Mock()
        d1 = preprocess_domain_single(self.iris.domain, call_cv)
        d2 = preprocess_domain_single(d1, call_cv)
        call_shared = Mock()
        d3 = preprocess_domain_shared(d2, call_cv, call_shared)
        d4 = preprocess_domain_single(d3, call_cv)
        self.iris.transform(d4)
        self.assertEqual(1, call_shared.call_count)
        self.assertEqual(16, call_cv.call_count)

    def test_simple_simple_shared_simple_shared_simple(self):
        # Two shared layers in the chain: each shared part runs once.
        call_cv = Mock()
        d1 = preprocess_domain_single(self.iris.domain, call_cv)
        d2 = preprocess_domain_single(d1, call_cv)
        call_shared = Mock()
        d3 = preprocess_domain_shared(d2, call_cv, call_shared)
        d4 = preprocess_domain_single(d3, call_cv)
        d5 = preprocess_domain_shared(d4, call_cv, call_shared)
        d6 = preprocess_domain_single(d5, call_cv)
        self.iris.transform(d6)
        self.assertEqual(2, call_shared.call_count)
        self.assertEqual(24, call_cv.call_count)

    def test_same_simple_shared_union(self):
        # Two shared domains built on the same single-source domain and
        # merged into one: the common layer is computed only once.
        call_cv = Mock()
        call_shared = Mock()
        call_cvs = Mock()
        same_simple = preprocess_domain_single(self.iris.domain, call_cv)
        s1 = preprocess_domain_shared(same_simple, call_cvs, call_shared)
        s2 = preprocess_domain_shared(same_simple, call_cvs, call_shared)
        ndom = Domain(s1.attributes + s2.attributes)
        self.iris.transform(ndom)
        self.assertEqual(2, call_shared.call_count)
        self.assertEqual(4, call_cv.call_count)
        self.assertEqual(8, call_cvs.call_count)

    def test_simple_simple_stupid(self):
        # Compute values that needlessly transform the whole input domain.
        call_cv = Mock()
        d1 = preprocess_domain_single_stupid(self.iris.domain, call_cv)
        d2 = preprocess_domain_single_stupid(d1, call_cv)
        self.iris.transform(d2)
        self.assertEqual(8, call_cv.call_count)
2835+
2836+
26962837
# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)