diff --git a/examples/plot_vertical.py b/examples/plot_vertical.py index 5b3113b..54f4c1f 100644 --- a/examples/plot_vertical.py +++ b/examples/plot_vertical.py @@ -7,7 +7,7 @@ """ from matplotlib import pyplot as plt -from upsetplot import generate_counts, plot +from upsetplot import generate_counts, plot, plotting example = generate_counts() plot(example, orientation='vertical') @@ -26,3 +26,16 @@ show_percentages=True) plt.suptitle('With counts and percentages shown') plt.show() + +######################################################################### +# An UpSet plot with additional plots in vertical orientation +# and tuning some visual parameters +example = generate_counts(extra_columns=2) +fig = plotting.UpSet(example, orientation='vertical', + show_counts=True, facecolor="grey", + element_size=75) +fig.add_catplot('swarm', 'value', palette='colorblind') +fig.add_catplot('swarm', 'value1', palette='colorblind') +fig.add_catplot('swarm', 'value2', palette='colorblind') +fig.plot() +plt.show() diff --git a/upsetplot/data.py b/upsetplot/data.py index 76e1531..e58d95d 100644 --- a/upsetplot/data.py +++ b/upsetplot/data.py @@ -8,7 +8,7 @@ import numpy as np -def generate_samples(seed=0, n_samples=10000, n_categories=3): +def generate_samples(seed=0, n_samples=10000, n_categories=3, extra_columns=0): """Generate artificial samples assigned to set intersections Parameters @@ -19,12 +19,16 @@ def generate_samples(seed=0, n_samples=10000, n_categories=3): Number of samples to generate n_categories : int Number of categories (named "cat0", "cat1", ...) to generate + extra_columns : int + If a vector is required, this indicates the number of additional + columns (named "value", "value1", "value2", ... ) Returns ------- DataFrame Field 'value' is a weight or score for each element. Field 'index' is a unique id for each element. + Field(s) 'value{i}' additional values for multiple-feature samples Index includes a boolean indicator mask for each category. 
Note: Further fields may be added in future versions. @@ -34,19 +38,25 @@ def generate_samples(seed=0, n_samples=10000, n_categories=3): generate_counts : Generates the counts for each subset of categories corresponding to these samples. """ + assert extra_columns >= 0, 'extra_columns parameter should be possitive' rng = np.random.RandomState(seed) - df = pd.DataFrame({'value': np.zeros(n_samples)}) + len_samples = 1 + extra_columns + df = pd.DataFrame(np.zeros((n_samples, len_samples))) + valuename_lst = [f'value{i}' if i > 0 else 'value' for i in + range(len_samples)] + df.columns = valuename_lst + for i in range(n_categories): - r = rng.rand(n_samples) - df['cat%d' % i] = r > rng.rand() - df['value'] += r + r = rng.rand(n_samples, len_samples) + df[f'cat{i}'] = r[:, 0] > rng.rand() + df[valuename_lst] += r df.reset_index(inplace=True) - df.set_index(['cat%d' % i for i in range(n_categories)], inplace=True) + df.set_index([f'cat{i}' for i in range(n_categories)], inplace=True) return df -def generate_counts(seed=0, n_samples=10000, n_categories=3): +def generate_counts(seed=0, n_samples=10000, n_categories=3, extra_columns=0): """Generate artificial counts corresponding to set intersections Parameters @@ -57,20 +67,30 @@ def generate_counts(seed=0, n_samples=10000, n_categories=3): Number of samples to generate statistics over n_categories : int Number of categories (named "cat0", "cat1", ...) to generate + extra_columns : int + Number of additional features to be used to generate each + sample (value, value1, value2, ...) Returns ------- - Series - Counts indexed by boolean indicator mask for each category. + Series or DataFrame + A Series of counts indexed by boolean indicator mask for each category, + when ``extra_columns`` is 0. Otherwise a DataFrame with column ``value`` + equivalent to the value produced when ``extra_columns`` is 0, as well as + further random variables ``value1``, ``value2``, for extra columns. 
See Also -------- generate_samples : Generates a DataFrame of samples that these counts are derived from. """ + assert extra_columns >= 0, 'extra_columns parameter should be possitive' df = generate_samples(seed=seed, n_samples=n_samples, - n_categories=n_categories) - return df.value.groupby(level=list(range(n_categories))).count() + n_categories=n_categories, + extra_columns=extra_columns) + df.drop('index', axis=1, inplace=True) + df = df if extra_columns > 0 else df.value + return df.groupby(level=list(range(n_categories))).count() def generate_data(seed=0, n_samples=10000, n_sets=3, aggregated=False): diff --git a/upsetplot/tests/test_data.py b/upsetplot/tests/test_data.py index 5937762..4bd3634 100644 --- a/upsetplot/tests/test_data.py +++ b/upsetplot/tests/test_data.py @@ -3,10 +3,11 @@ import pandas as pd import numpy as np from distutils.version import LooseVersion -from pandas.util.testing import (assert_series_equal, assert_frame_equal, - assert_index_equal) +from pandas.testing import (assert_series_equal, assert_frame_equal, + assert_index_equal) from upsetplot import (from_memberships, from_contents, from_indicators, generate_data) +from upsetplot.data import (generate_samples, generate_counts) @pytest.mark.parametrize('typ', [set, list, tuple, iter]) @@ -207,6 +208,75 @@ def test_from_indicators_equivalence(indicators, data): from_memberships([[], ["cat1"], []], data)) -def test_generate_data_warning(): - with pytest.warns(DeprecationWarning): - generate_data() +class TestGenerateData: + def test_generate_data_warning(self): + ''' + Check the warning raised by the function + ''' + with pytest.warns(DeprecationWarning): + generate_data() + + def test_generate_default(self): + ''' + Check that the generated data by default, fulfills the + correct dimensions of the data + ''' + result = generate_data() + assert len(result.index[0]) == 3 + assert result.shape == (10_000,) + + def test_generate_samples_reproductibility(self): + ''' + This test explores 
the reproducibility of the results + when a random seed has been set + ''' + import numpy as np + seed = np.random.randint(0, 100) + assert generate_samples(seed=seed).equals(generate_samples(seed=seed)) + + @pytest.mark.parametrize("n_samples", [100, 1_000, 10_000]) + @pytest.mark.parametrize("n_categories", [1, 3]) + @pytest.mark.parametrize("extra_columns", [0, 2]) + def test_generate_samples_shapes(self, n_samples, n_categories, + extra_columns): + ''' + Check the generation of different sample sizes with different + arguments + NOTICE: the generate_samples function has one extra + column due to index, unless it is unused and it is removed + ''' + result = generate_samples(n_samples=n_samples, + n_categories=n_categories, + extra_columns=extra_columns) + + if type(result.index[0]) is tuple: + assert len(result.index[0]) == n_categories + else: + assert result.index.is_boolean() + + assert result.shape == (n_samples, extra_columns + 2) + + @pytest.mark.parametrize("n_samples", [100, 1_000, 10_000]) + @pytest.mark.parametrize("extra_columns", [0, 2]) + def test_generate_counts(self, n_samples, extra_columns): + ''' + Test of the function generate_counts + which internally uses generate_samples + ''' + result = generate_counts(n_samples=n_samples, + extra_columns=extra_columns) + if extra_columns: + assert len(result.columns) == extra_columns + 1 + assert (result.sum(axis=0) == n_samples).all() + + @pytest.mark.parametrize("aggregated", [True, False]) + def test_generate_data(self, aggregated): + ''' + Test the return of the deprecated method + generate_data + ''' + data = generate_data(aggregated=aggregated) + if aggregated: + assert data.equals(generate_counts()) + else: + assert data.equals(generate_samples().value) diff --git a/upsetplot/tests/test_upsetplot.py b/upsetplot/tests/test_upsetplot.py index 2ee0e6e..e9c9f19 100644 --- a/upsetplot/tests/test_upsetplot.py +++ b/upsetplot/tests/test_upsetplot.py @@ -43,7 +43,6 @@ def get_all_texts(mpl_artist): 
'sort_categories_by', [None, 'input', '-input', 'cardinality', '-cardinality']) def test_process_data_series(x, sort_by, sort_categories_by): - assert x.name == 'value' for subset_size in ['auto', 'sum', 'count']: for sum_over in ['abc', False]: with pytest.raises(ValueError, match='sum_over is not applicable'):