@@ -681,3 +681,84 @@ def test_bar_plt_xaxis_intervalrange(self):
681681 (a .get_text () == b .get_text ())
682682 for a , b in zip (s .plot .bar ().get_xticklabels (), expected )
683683 )
684+
685+ @pytest .fixture (scope = "class" )
686+ def BSS_data () -> np .array :
687+ yield np .random .default_rng (3 ).integers (0 ,100 ,5 )
688+
689+ @pytest .fixture (scope = "class" )
690+ def BSS_df (BSS_data ) -> DataFrame :
691+ BSS_df = DataFrame ({"A" : BSS_data , "B" : BSS_data [::- 1 ], "C" : BSS_data [0 ], "D" : BSS_data [- 1 ]})
692+ return BSS_df
693+
694+ def _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division ):
695+ subplot_data_df_list = []
696+
697+ # get xy and height of squares that represent the data graphed from the df, seperated by subplots
698+ for i in range (len (subplot_division )):
699+ subplot_data = np .array ([(x .get_x (), x .get_y (), x .get_height ()) for x in ax [i ].findobj (plt .Rectangle ) if x .get_height () in BSS_data ])
700+ subplot_data_df_list .append (DataFrame (data = subplot_data , columns = ["x_coord" , "y_coord" , "height" ]))
701+
702+ return subplot_data_df_list
703+
704+ def _BSS_subplot_checker (BSS_data , BSS_df , subplot_data_df , subplot_columns ):
705+ assert_flag = 0
706+ subplot_sliced_by_source = [subplot_data_df .iloc [len (BSS_data ) * i : len (BSS_data ) * (i + 1 )].reset_index () for i in range (0 , len (subplot_columns ))]
707+ expected_total_height = BSS_df .loc [:,subplot_columns ].sum (axis = 1 )
708+
709+ for i in range (len (subplot_columns )):
710+ sliced_df = subplot_sliced_by_source [i ]
711+ if i == 0 :
712+ #Checks that the bar chart starts y=0
713+ assert (sliced_df ["y_coord" ] == 0 ).all
714+ height_iter = sliced_df ["y_coord" ].add (sliced_df ["height" ])
715+ else :
716+ height_iter = height_iter + sliced_df ["height" ]
717+
718+ if i + 1 == len (subplot_columns ):
719+ #Checks final height matches what is expected
720+ tm .assert_series_equal (height_iter , expected_total_height , check_names = False , check_dtype = False )
721+
722+ else :
723+ #Checks each preceding bar ends where the next one starts
724+ next_start_coord = subplot_sliced_by_source [i + 1 ]["y_coord" ]
725+ tm .assert_series_equal (height_iter , next_start_coord , check_names = False , check_dtype = False )
726+
727+ class TestBarSubplotStacked :
728+ #GH Issue 61018
729+ def test_bar_1_subplot_1_double_stacked (self , BSS_data , BSS_df ):
730+ columns_used = ["A" , "B" ]
731+ BSS_df_trimmed = BSS_df [columns_used ]
732+ subplot_division = [columns_used ]
733+ ax = BSS_df_trimmed .plot (subplots = subplot_division , kind = "bar" , stacked = True )
734+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
735+ for i in range (len (subplot_data_df_list )):
736+ _BSS_subplot_checker (BSS_data , BSS_df_trimmed , subplot_data_df_list [i ], subplot_division [i ])
737+ plt .savefig ("1s1d.png" )
738+
739+
740+ def test_bar_2_subplot_1_double_stacked (self , BSS_data , BSS_df ):
741+ columns_used = ["A" , "B" , "C" ]
742+ BSS_df_trimmed = BSS_df [columns_used ]
743+ subplot_division = [(columns_used [0 ], columns_used [1 ]), (columns_used [2 ],)]
744+ ax = BSS_df_trimmed .plot (subplots = subplot_division , kind = "bar" , stacked = True )
745+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
746+ for i in range (len (subplot_data_df_list )):
747+ _BSS_subplot_checker (BSS_data , BSS_df_trimmed , subplot_data_df_list [i ], subplot_division [i ])
748+ plt .savefig ("2s1d.png" )
749+
750+ def test_bar_2_subplot_2_double_stacked (self , BSS_data , BSS_df ):
751+ subplot_division = [('A' , 'D' ), ('C' , 'B' )]
752+ ax = BSS_df .plot (subplots = subplot_division , kind = "bar" , stacked = True )
753+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
754+ for i in range (len (subplot_data_df_list )):
755+ _BSS_subplot_checker (BSS_data , BSS_df , subplot_data_df_list [i ], subplot_division [i ])
756+ plt .savefig ("2s2d.png" )
757+
758+ def test_bar_2_subplots_1_triple_stacked (self , BSS_data , BSS_df ):
759+ subplot_division = [('A' , 'D' , 'C' )]
760+ ax = BSS_df .plot (subplots = subplot_division , kind = "bar" , stacked = True )
761+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
762+ for i in range (len (subplot_data_df_list )):
763+ _BSS_subplot_checker (BSS_data , BSS_df , subplot_data_df_list [i ], subplot_division [i ])
764+ plt .savefig ("2s1t.png" )
0 commit comments