@@ -682,93 +682,128 @@ def test_bar_plt_xaxis_intervalrange(self):
682682 for a , b in zip (s .plot .bar ().get_xticklabels (), expected )
683683 )
684684
685+
685686@pytest .fixture (scope = "class" )
686687def BSS_data () -> np .array :
687- yield np .random .default_rng (3 ).integers (0 ,100 ,5 )
688+ return np .random .default_rng (3 ).integers (0 , 100 , 5 )
689+
688690
689691@pytest .fixture (scope = "class" )
690692def BSS_df (BSS_data ) -> DataFrame :
691- BSS_df = DataFrame ({"A" : BSS_data , "B" : BSS_data [::- 1 ], "C" : BSS_data [0 ], "D" : BSS_data [- 1 ]})
693+ BSS_df = DataFrame (
694+ {"A" : BSS_data , "B" : BSS_data [::- 1 ], "C" : BSS_data [0 ], "D" : BSS_data [- 1 ]}
695+ )
692696 return BSS_df
693697
698+
694699def _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division ):
695700 subplot_data_df_list = []
696701
697- # get xy and height of squares that represent the data graphed from the df, seperated by subplots
702+ # get xy and height of squares representing data, separated by subplots
698703 for i in range (len (subplot_division )):
699- subplot_data = np .array ([(x .get_x (), x .get_y (), x .get_height ()) for x in ax [i ].findobj (plt .Rectangle ) if x .get_height () in BSS_data ])
700- subplot_data_df_list .append (DataFrame (data = subplot_data , columns = ["x_coord" , "y_coord" , "height" ]))
704+ subplot_data = np .array (
705+ [
706+ (x .get_x (), x .get_y (), x .get_height ())
707+ for x in ax [i ].findobj (plt .Rectangle )
708+ if x .get_height () in BSS_data
709+ ]
710+ )
711+ subplot_data_df_list .append (
712+ DataFrame (data = subplot_data , columns = ["x_coord" , "y_coord" , "height" ])
713+ )
701714
702715 return subplot_data_df_list
703716
717+
704718def _BSS_subplot_checker (BSS_data , BSS_df , subplot_data_df , subplot_columns ):
705- assert_flag = 0
706- subplot_sliced_by_source = [subplot_data_df .iloc [len (BSS_data ) * i : len (BSS_data ) * (i + 1 )].reset_index () for i in range (0 , len (subplot_columns ))]
707- expected_total_height = BSS_df .loc [:,subplot_columns ].sum (axis = 1 )
708-
719+ subplot_sliced_by_source = [
720+ subplot_data_df .iloc [len (BSS_data ) * i : len (BSS_data ) * (i + 1 )].reset_index ()
721+ for i in range (len (subplot_columns ))
722+ ]
723+ expected_total_height = BSS_df .loc [:, subplot_columns ].sum (axis = 1 )
724+
709725 for i in range (len (subplot_columns )):
710726 sliced_df = subplot_sliced_by_source [i ]
711727 if i == 0 :
712- #Checks that the bar chart starts y=0
728+ # Checks that the bar chart starts y=0
713729 assert (sliced_df ["y_coord" ] == 0 ).all
714730 height_iter = sliced_df ["y_coord" ].add (sliced_df ["height" ])
715731 else :
716732 height_iter = height_iter + sliced_df ["height" ]
717733
718- if i + 1 == len (subplot_columns ):
719- #Checks final height matches what is expected
720- tm .assert_series_equal (height_iter , expected_total_height , check_names = False , check_dtype = False )
721-
734+ if i + 1 == len (subplot_columns ):
735+ # Checks final height matches what is expected
736+ tm .assert_series_equal (
737+ height_iter , expected_total_height , check_names = False , check_dtype = False
738+ )
739+
722740 else :
723- #Checks each preceding bar ends where the next one starts
724- next_start_coord = subplot_sliced_by_source [i + 1 ]["y_coord" ]
725- tm .assert_series_equal (height_iter , next_start_coord , check_names = False , check_dtype = False )
741+ # Checks each preceding bar ends where the next one starts
742+ next_start_coord = subplot_sliced_by_source [i + 1 ]["y_coord" ]
743+ tm .assert_series_equal (
744+ height_iter , next_start_coord , check_names = False , check_dtype = False
745+ )
746+
726747
727748class TestBarSubplotStacked :
728- #GH Issue 61018
729- @pytest .mark .parametrize ("columns_used" ,[["A" , "B" ],
730- ["C" , "D" ],
731- ["D" , "A" ]
732- ])
749+ # GH Issue 61018
750+ @pytest .mark .parametrize ("columns_used" , [["A" , "B" ], ["C" , "D" ], ["D" , "A" ]])
733751 def test_bar_1_subplot_1_double_stacked (self , BSS_data , BSS_df , columns_used ):
734752 BSS_df_trimmed = BSS_df [columns_used ]
735753 subplot_division = [columns_used ]
736- ax = BSS_df_trimmed .plot (subplots = subplot_division , kind = "bar" , stacked = True )
737- subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
754+ ax = BSS_df_trimmed .plot (subplots = subplot_division , kind = "bar" , stacked = True )
755+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (
756+ BSS_data , ax , subplot_division
757+ )
738758 for i in range (len (subplot_data_df_list )):
739- _BSS_subplot_checker (BSS_data , BSS_df_trimmed , subplot_data_df_list [i ], subplot_division [i ])
759+ _BSS_subplot_checker (
760+ BSS_data , BSS_df_trimmed , subplot_data_df_list [i ], subplot_division [i ]
761+ )
740762
741- @pytest .mark .parametrize ("columns_used" ,[["A" , "B" , "C" ],
742- ["A" , "C" , "B" ],
743- ["D" , "A" , "C" ]
744-
745- ])
763+ @pytest .mark .parametrize (
764+ "columns_used" , [["A" , "B" , "C" ], ["A" , "C" , "B" ], ["D" , "A" , "C" ]]
765+ )
746766 def test_bar_2_subplot_1_double_stacked (self , BSS_data , BSS_df , columns_used ):
747- BSS_df_trimmed = BSS_df [columns_used ]
767+ BSS_df_trimmed = BSS_df [columns_used ]
748768 subplot_division = [(columns_used [0 ], columns_used [1 ]), (columns_used [2 ],)]
749- ax = BSS_df_trimmed .plot (subplots = subplot_division , kind = "bar" , stacked = True )
750- subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
769+ ax = BSS_df_trimmed .plot (subplots = subplot_division , kind = "bar" , stacked = True )
770+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (
771+ BSS_data , ax , subplot_division
772+ )
751773 for i in range (len (subplot_data_df_list )):
752- _BSS_subplot_checker (BSS_data , BSS_df_trimmed , subplot_data_df_list [i ], subplot_division [i ])
774+ _BSS_subplot_checker (
775+ BSS_data , BSS_df_trimmed , subplot_data_df_list [i ], subplot_division [i ]
776+ )
753777
754- @pytest .mark .parametrize ("subplot_division" , [[("A" , "B" ), ("C" , "D" )],
755- [("A" , "D" ), ("C" , "B" )],
756- [("B" , "C" ), ("D" , "A" )],
757- [("B" , "D" ), ("C" , "A" )]
758- ])
778+ @pytest .mark .parametrize (
779+ "subplot_division" ,
780+ [
781+ [("A" , "B" ), ("C" , "D" )],
782+ [("A" , "D" ), ("C" , "B" )],
783+ [("B" , "C" ), ("D" , "A" )],
784+ [("B" , "D" ), ("C" , "A" )],
785+ ],
786+ )
759787 def test_bar_2_subplot_2_double_stacked (self , BSS_data , BSS_df , subplot_division ):
760- ax = BSS_df .plot (subplots = subplot_division , kind = "bar" , stacked = True )
761- subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
788+ ax = BSS_df .plot (subplots = subplot_division , kind = "bar" , stacked = True )
789+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (
790+ BSS_data , ax , subplot_division
791+ )
762792 for i in range (len (subplot_data_df_list )):
763- _BSS_subplot_checker (BSS_data , BSS_df , subplot_data_df_list [i ], subplot_division [i ])
764-
765- @pytest .mark .parametrize ("subplot_division" , [[("A" , "B" , "C" )],
766- [("A" , "D" , "B" )],
767- [("C" , "A" , "D" )],
768- [("D" , "C" , "A" )]
769- ])
793+ _BSS_subplot_checker (
794+ BSS_data , BSS_df , subplot_data_df_list [i ], subplot_division [i ]
795+ )
796+
797+ @pytest .mark .parametrize (
798+ "subplot_division" ,
799+ [[("A" , "B" , "C" )], [("A" , "D" , "B" )], [("C" , "A" , "D" )], [("D" , "C" , "A" )]],
800+ )
770801 def test_bar_2_subplots_1_triple_stacked (self , BSS_data , BSS_df , subplot_division ):
771- ax = BSS_df .plot (subplots = subplot_division , kind = "bar" , stacked = True )
772- subplot_data_df_list = _BSS_xyheight_from_ax_helper (BSS_data , ax , subplot_division )
802+ ax = BSS_df .plot (subplots = subplot_division , kind = "bar" , stacked = True )
803+ subplot_data_df_list = _BSS_xyheight_from_ax_helper (
804+ BSS_data , ax , subplot_division
805+ )
773806 for i in range (len (subplot_data_df_list )):
774- _BSS_subplot_checker (BSS_data , BSS_df , subplot_data_df_list [i ], subplot_division [i ])
807+ _BSS_subplot_checker (
808+ BSS_data , BSS_df , subplot_data_df_list [i ], subplot_division [i ]
809+ )
0 commit comments