|
17 | 17 | from Orange import data |
18 | 18 | from Orange.data import (filter, Unknown, Variable, Table, DiscreteVariable, |
19 | 19 | ContinuousVariable, Domain, StringVariable) |
| 20 | +from Orange.data.util import SharedComputeValue |
20 | 21 | from Orange.tests import test_dirname |
21 | 22 | from Orange.data.table import _optimize_indices |
22 | 23 |
|
@@ -2693,5 +2694,145 @@ def run_from_table(): |
2693 | 2694 | self.assertLess(duration, 0.5) |
2694 | 2695 |
|
2695 | 2696 |
|
class PreprocessComputeValue:
    """Single-source compute value that reports every invocation.

    Instances are callable. Each call notifies ``callback`` with the data
    it received, transforms that data into ``domain`` and returns the first
    column of the data's ``X`` matrix.
    """

    def __init__(self, domain, callback):
        """Store the domain to transform into and the callable to notify."""
        self.domain, self.callback = domain, callback

    def __call__(self, data_):
        """Notify the callback, transform ``data_`` and return ``X[:, 0]``."""
        notify = self.callback
        notify(data_)
        # The transformed table is thrown away; the returned column is read
        # from the *input* data.
        # NOTE(review): the transform presumably exists only to trigger the
        # compute values of ``self.domain`` so they can be counted - confirm.
        data_.transform(self.domain)
        first_column = data_.X[:, 0]
        return first_column
| 2707 | + |
| 2708 | + |
class PreprocessShared:
    """Callable that reports its invocation and transforms the data.

    Within this file it is passed as ``compute_shared`` to
    ``PreprocessSharedComputeValue``.
    """

    def __init__(self, domain, callback):
        """Store the domain to transform into and the callable to notify."""
        self.domain, self.callback = domain, callback

    def __call__(self, data_):
        """Notify the callback, transform ``data_`` and return ``True``."""
        notify = self.callback
        notify(data_)
        # The result of the transform is intentionally not kept; the value
        # handed back to the caller is simply the constant ``True``.
        data_.transform(self.domain)
        return True
| 2719 | + |
| 2720 | + |
class PreprocessSharedComputeValue(SharedComputeValue):
    """Shared compute value that reports every ``compute`` invocation.

    ``shared`` is forwarded to the base class as ``compute_shared``;
    ``callback`` is notified each time ``compute`` runs.
    """

    def __init__(self, callback, shared):
        """Register ``shared`` with the base class and keep ``callback``."""
        super().__init__(compute_shared=shared)
        self.callback = callback

    # pylint: disable=arguments-differ
    def compute(self, data_, shared_data):
        """Notify the callback and return the first column of ``data_.X``.

        ``shared_data`` is accepted but not used.
        """
        notify = self.callback
        notify(data_)
        first_column = data_.X[:, 0]
        return first_column
| 2731 | + |
| 2732 | + |
def preprocess_domain_single(domain, callback):
    """ Preprocess domain with single-source compute values.

    For every attribute of `domain` a ContinuousVariable with the same name
    is created; its compute value is a PreprocessComputeValue over a domain
    containing only that attribute. The new variables form the attributes
    of the returned Domain.
    """
    variables = []
    for attr in domain.attributes:
        compute = PreprocessComputeValue(Domain([attr]), callback)
        variables.append(
            ContinuousVariable(name=attr.name, compute_value=compute))
    return Domain(variables)
| 2740 | + |
| 2741 | + |
def preprocess_domain_shared(domain, callback, callback_shared):
    """ Preprocess domain with shared compute values.

    A single PreprocessShared (built from the whole `domain` and
    `callback_shared`) is reused by every new variable. Each attribute of
    `domain` yields a ContinuousVariable with the same name whose compute
    value is a fresh PreprocessSharedComputeValue reporting to `callback`.
    """
    # One shared object for all variables of the resulting domain.
    shared = PreprocessShared(domain, callback_shared)
    variables = []
    for attr in domain.attributes:
        compute = PreprocessSharedComputeValue(callback, shared)
        variables.append(
            ContinuousVariable(name=attr.name, compute_value=compute))
    return Domain(variables)
| 2750 | + |
| 2751 | + |
def preprocess_domain_single_stupid(domain, callback):
    """ Preprocess domain with single-source compute values, implemented
    in a deliberately wasteful way: every compute value is given the whole
    input domain, so applying it transforms all columns instead of just
    the one column it needs.
    """
    variables = []
    for attr in domain.attributes:
        # Unlike preprocess_domain_single, the full `domain` is passed here.
        compute = PreprocessComputeValue(domain, callback)
        variables.append(
            ContinuousVariable(name=attr.name, compute_value=compute))
    return Domain(variables)
| 2761 | + |
| 2762 | + |
class EfficientTransformTests(unittest.TestCase):
    """Count compute-value invocations during a single table transform.

    Every test derives one or more domains from the iris domain, transforms
    the table once into the final domain and asserts how often the
    per-variable and shared callbacks (Mock objects) were called.
    """

    def setUp(self):
        self.iris = Table("iris")

    def test_simple(self):
        """One single-source layer: the callback is called 4 times."""
        cv_calls = Mock()
        target = preprocess_domain_single(self.iris.domain, cv_calls)
        self.iris.transform(target)
        self.assertEqual(4, cv_calls.call_count)

    def test_shared(self):
        """One shared layer: 4 compute calls, 1 shared call."""
        cv_calls = Mock()
        shared_calls = Mock()
        target = preprocess_domain_shared(
            self.iris.domain, cv_calls, shared_calls)
        self.iris.transform(target)
        self.assertEqual(4, cv_calls.call_count)
        self.assertEqual(1, shared_calls.call_count)

    def test_simple_simple_shared(self):
        """single -> single -> shared: 1 shared call, 12 compute calls."""
        cv_calls = Mock()
        shared_calls = Mock()
        layer1 = preprocess_domain_single(self.iris.domain, cv_calls)
        layer2 = preprocess_domain_single(layer1, cv_calls)
        layer3 = preprocess_domain_shared(layer2, cv_calls, shared_calls)
        self.iris.transform(layer3)
        self.assertEqual(1, shared_calls.call_count)
        self.assertEqual(12, cv_calls.call_count)

    def test_simple_simple_shared_simple(self):
        """single -> single -> shared -> single: 1 shared, 16 compute."""
        cv_calls = Mock()
        shared_calls = Mock()
        layer1 = preprocess_domain_single(self.iris.domain, cv_calls)
        layer2 = preprocess_domain_single(layer1, cv_calls)
        layer3 = preprocess_domain_shared(layer2, cv_calls, shared_calls)
        layer4 = preprocess_domain_single(layer3, cv_calls)
        self.iris.transform(layer4)
        self.assertEqual(1, shared_calls.call_count)
        self.assertEqual(16, cv_calls.call_count)

    def test_simple_simple_shared_simple_shared_simple(self):
        """Six layers with two shared ones: 2 shared, 24 compute calls."""
        cv_calls = Mock()
        shared_calls = Mock()
        layer1 = preprocess_domain_single(self.iris.domain, cv_calls)
        layer2 = preprocess_domain_single(layer1, cv_calls)
        layer3 = preprocess_domain_shared(layer2, cv_calls, shared_calls)
        layer4 = preprocess_domain_single(layer3, cv_calls)
        layer5 = preprocess_domain_shared(layer4, cv_calls, shared_calls)
        layer6 = preprocess_domain_single(layer5, cv_calls)
        self.iris.transform(layer6)
        self.assertEqual(2, shared_calls.call_count)
        self.assertEqual(24, cv_calls.call_count)

    def test_same_simple_shared_union(self):
        """Two shared domains over one common single-source domain.

        The attributes of both shared domains are joined into one domain;
        the common single-source layer is expected to be computed only
        once (4 calls), with 2 shared calls and 8 shared-compute calls.
        """
        cv_calls = Mock()
        shared_calls = Mock()
        shared_cv_calls = Mock()
        common = preprocess_domain_single(self.iris.domain, cv_calls)
        first = preprocess_domain_shared(
            common, shared_cv_calls, shared_calls)
        second = preprocess_domain_shared(
            common, shared_cv_calls, shared_calls)
        union = Domain(first.attributes + second.attributes)
        self.iris.transform(union)
        self.assertEqual(2, shared_calls.call_count)
        self.assertEqual(4, cv_calls.call_count)
        self.assertEqual(8, shared_cv_calls.call_count)

    def test_simple_simple_stupid(self):
        """Two whole-domain ("stupid") layers: 8 compute calls."""
        cv_calls = Mock()
        layer1 = preprocess_domain_single_stupid(self.iris.domain, cv_calls)
        layer2 = preprocess_domain_single_stupid(layer1, cv_calls)
        self.iris.transform(layer2)
        self.assertEqual(8, cv_calls.call_count)
| 2835 | + |
| 2836 | + |
# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments