Skip to content

Commit 1fe342a

Browse files
committed
Updated test cases to include more subplot stacking possibilities
1 parent 46c6eaa commit 1fe342a

File tree

2 files changed

+81
-40
lines changed

2 files changed

+81
-40
lines changed

pandas/tests/plotting/test_common.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import numpy as np
33
from pandas import DataFrame
4-
from pandas import unique
54
from pandas.tests.plotting.common import (
65
_check_plot_works,
76
_check_ticks_props,
@@ -59,45 +58,6 @@ def test_colorbar_layout(self):
5958

6059
fig.colorbar(cs0, ax=[axes["A"], axes["B"]], location="right")
6160
DataFrame(x).plot(ax=axes["C"])
62-
63-
def test_bar_subplot_stacking(self):
64-
#GH Issue 61018
65-
test_data = np.random.default_rng(3).integers(0,100,5)
66-
df = DataFrame({"A": test_data, "B": test_data[::-1], "C": test_data[0]})
67-
ax = df.plot(subplots= [('A','B')], kind="bar", stacked=True)
68-
69-
#finds all the rectangles that represent the values from both subplots
70-
data_from_subplots = [[(x.get_x(), x.get_y(), x.get_height()) for x in ax[i].findobj(plt.Rectangle) if x.get_height() in test_data] for i in range(0,2)]
71-
72-
#get xy and height of squares that represent the data graphed from the df
73-
#we would expect the height value of A to be reflected in the Y coord of B in subplot 1
74-
subplot_data_df_list = []
75-
unique_x_loc_list = []
76-
for i in range(0,len(data_from_subplots)):
77-
subplot_data_df= DataFrame(data = data_from_subplots[i], columns = ["x_coord", "y_coord", "height"])
78-
unique_x_loc = unique(subplot_data_df["x_coord"])
79-
80-
subplot_data_df_list.append(subplot_data_df)
81-
unique_x_loc_list.append(unique_x_loc)
82-
83-
#Checks subplot 1
84-
plot_A_df = subplot_data_df_list[0].iloc[:len(test_data)]
85-
plot_B_df = subplot_data_df_list[0].iloc[len(test_data):].reset_index()
86-
total_bar_height = plot_A_df["height"].add(plot_B_df["height"])
87-
#check number of bars matches the number of data plotted
88-
assert len(unique_x_loc_list[0]) == len(test_data)
89-
#checks that the first set of bars are the correct height and that the second one starts at the top of the first, additional checks the combined height of the bars are correct
90-
assert (plot_A_df["height"] == test_data).all()
91-
assert (plot_B_df["y_coord"] == test_data).all()
92-
assert (total_bar_height == test_data + test_data[::-1]).all()
93-
94-
#Checks subplot 2
95-
plot_C_df = subplot_data_df_list[1].iloc[:len(test_data)]
96-
#check number of bars matches the number of data plotted
97-
assert len(unique_x_loc_list[1]) == len(test_data)
98-
#checks that all the bars start at zero and are the correct height
99-
assert (plot_C_df["height"] == test_data[0]).all()
100-
assert (plot_C_df["y_coord"] == 0).all()
10161

10262

10363

pandas/tests/plotting/test_misc.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,84 @@ def test_bar_plt_xaxis_intervalrange(self):
681681
(a.get_text() == b.get_text())
682682
for a, b in zip(s.plot.bar().get_xticklabels(), expected)
683683
)
684+
685+
@pytest.fixture(scope="class")
686+
def BSS_data() -> np.array:
687+
yield np.random.default_rng(3).integers(0,100,5)
688+
689+
@pytest.fixture(scope="class")
690+
def BSS_df(BSS_data) -> DataFrame:
691+
BSS_df = DataFrame({"A": BSS_data, "B": BSS_data[::-1], "C": BSS_data[0], "D": BSS_data[-1]})
692+
return BSS_df
693+
694+
def _BSS_xyheight_from_ax_helper(BSS_data, ax, subplot_division):
695+
subplot_data_df_list = []
696+
697+
# get xy and height of squares that represent the data graphed from the df, seperated by subplots
698+
for i in range(len(subplot_division)):
699+
subplot_data = np.array([(x.get_x(), x.get_y(), x.get_height()) for x in ax[i].findobj(plt.Rectangle) if x.get_height() in BSS_data])
700+
subplot_data_df_list.append(DataFrame(data = subplot_data, columns = ["x_coord", "y_coord", "height"]))
701+
702+
return subplot_data_df_list
703+
704+
def _BSS_subplot_checker(BSS_data, BSS_df, subplot_data_df, subplot_columns):
705+
assert_flag = 0
706+
subplot_sliced_by_source = [subplot_data_df.iloc[len(BSS_data) * i : len(BSS_data) * (i+1)].reset_index() for i in range(0, len(subplot_columns))]
707+
expected_total_height = BSS_df.loc[:,subplot_columns].sum(axis=1)
708+
709+
for i in range(len(subplot_columns)):
710+
sliced_df = subplot_sliced_by_source[i]
711+
if i == 0:
712+
#Checks that the bar chart starts y=0
713+
assert (sliced_df["y_coord"] == 0).all
714+
height_iter = sliced_df["y_coord"].add(sliced_df["height"])
715+
else:
716+
height_iter = height_iter + sliced_df["height"]
717+
718+
if i+1 == len(subplot_columns):
719+
#Checks final height matches what is expected
720+
tm.assert_series_equal(height_iter, expected_total_height, check_names = False, check_dtype= False)
721+
722+
else:
723+
#Checks each preceding bar ends where the next one starts
724+
next_start_coord = subplot_sliced_by_source[i+1]["y_coord"]
725+
tm.assert_series_equal(height_iter, next_start_coord, check_names = False, check_dtype= False)
726+
727+
class TestBarSubplotStacked:
728+
#GH Issue 61018
729+
def test_bar_1_subplot_1_double_stacked(self, BSS_data, BSS_df):
730+
columns_used = ["A", "B"]
731+
BSS_df_trimmed = BSS_df[columns_used]
732+
subplot_division = [columns_used]
733+
ax = BSS_df_trimmed.plot(subplots = subplot_division, kind="bar", stacked=True)
734+
subplot_data_df_list = _BSS_xyheight_from_ax_helper(BSS_data, ax, subplot_division)
735+
for i in range(len(subplot_data_df_list)):
736+
_BSS_subplot_checker(BSS_data, BSS_df_trimmed, subplot_data_df_list[i], subplot_division[i])
737+
plt.savefig("1s1d.png")
738+
739+
740+
def test_bar_2_subplot_1_double_stacked(self, BSS_data, BSS_df):
741+
columns_used = ["A", "B", "C"]
742+
BSS_df_trimmed = BSS_df[columns_used]
743+
subplot_division = [(columns_used[0], columns_used[1]), (columns_used[2],)]
744+
ax = BSS_df_trimmed.plot(subplots = subplot_division, kind="bar", stacked=True)
745+
subplot_data_df_list = _BSS_xyheight_from_ax_helper(BSS_data, ax, subplot_division)
746+
for i in range(len(subplot_data_df_list)):
747+
_BSS_subplot_checker(BSS_data, BSS_df_trimmed, subplot_data_df_list[i], subplot_division[i])
748+
plt.savefig("2s1d.png")
749+
750+
def test_bar_2_subplot_2_double_stacked(self, BSS_data, BSS_df):
751+
subplot_division = [('A', 'D'), ('C', 'B')]
752+
ax = BSS_df.plot(subplots = subplot_division, kind="bar", stacked=True)
753+
subplot_data_df_list = _BSS_xyheight_from_ax_helper(BSS_data, ax, subplot_division)
754+
for i in range(len(subplot_data_df_list)):
755+
_BSS_subplot_checker(BSS_data, BSS_df, subplot_data_df_list[i], subplot_division[i])
756+
plt.savefig("2s2d.png")
757+
758+
def test_bar_2_subplots_1_triple_stacked(self, BSS_data, BSS_df):
759+
subplot_division = [('A', 'D', 'C')]
760+
ax = BSS_df.plot(subplots = subplot_division, kind="bar", stacked=True)
761+
subplot_data_df_list = _BSS_xyheight_from_ax_helper(BSS_data, ax, subplot_division)
762+
for i in range(len(subplot_data_df_list)):
763+
_BSS_subplot_checker(BSS_data, BSS_df, subplot_data_df_list[i], subplot_division[i])
764+
plt.savefig("2s1t.png")

0 commit comments

Comments
 (0)