44from  pydantic  import  BaseModel , Field 
55
66from  .base  import  Graph2D , GraphType 
7+ from  ..utils .rounding  import  dynamic_round 
78
89
910class  BarData (BaseModel ):
@@ -48,6 +49,7 @@ class BoxAndWhiskerData(BaseModel):
4849    median : float 
4950    third_quartile : float 
5051    max : float 
52+     outliers : List [float ]
5153
5254
5355class  BoxAndWhiskerGraph (Graph2D ):
@@ -56,22 +58,23 @@ class BoxAndWhiskerGraph(Graph2D):
5658    elements : List [BoxAndWhiskerData ] =  Field (default_factory = list )
5759
5860    def  _extract_info (self , ax : Axes ) ->  None :
59-         super (). _extract_info ( ax ) 
61+         labels   =  [ item . get_text ()  for   item   in   ax . get_xticklabels ()] 
6062
6163        boxes  =  []
62-         for  box  in  ax .patches :
64+         for  label ,  box  in  zip ( labels ,  ax .patches ) :
6365            vertices  =  box .get_path ().vertices 
64-             x_vertices  =  vertices [:, 0 ]
65-             y_vertices  =  vertices [:, 1 ]
66+             x_vertices  =  [ dynamic_round ( x )  for   x   in   vertices [:, 0 ] ]
67+             y_vertices  =  [ dynamic_round ( y )  for   y   in   vertices [:, 1 ] ]
6668            x  =  min (x_vertices )
6769            y  =  min (y_vertices )
6870            boxes .append (
6971                {
7072                    "x" : x ,
7173                    "y" : y ,
72-                     "label" : box .get_label (),
73-                     "width" : round (max (x_vertices ) -  x , 4 ),
74-                     "height" : round (max (y_vertices ) -  y , 4 ),
74+                     "label" : label ,
75+                     "width" : max (x_vertices ) -  x ,
76+                     "height" : max (y_vertices ) -  y ,
77+                     "outliers" : [],
7578                }
7679            )
7780
@@ -85,13 +88,21 @@ def _extract_info(self, ax: Axes) -> None:
8588                box ["x" ], box ["y" ] =  box ["y" ], box ["x" ]
8689                box ["width" ], box ["height" ] =  box ["height" ], box ["width" ]
8790
88-         for  line  in  ax .lines :
89-             xdata  =  line .get_xdata ()
90-             ydata  =  line .get_ydata ()
91+         for  i ,  line  in  enumerate ( ax .lines ) :
92+             xdata  =  [ dynamic_round ( x )  for   x   in   line .get_xdata ()] 
93+             ydata  =  [ dynamic_round ( y )  for   y   in   line .get_ydata ()] 
9194
9295            if  orientation  ==  "vertical" :
9396                xdata , ydata  =  ydata , xdata 
9497
98+             if  len (xdata ) ==  1 :
99+                 for  box  in  boxes :
100+                     if  box ["x" ] <=  xdata [0 ] <=  box ["x" ] +  box ["width" ]:
101+                         break 
102+                 else :
103+                     continue 
104+ 
105+                 box ["outliers" ].append (ydata [0 ])
95106            if  len (ydata ) !=  2 :
96107                continue 
97108            for  box  in  boxes :
@@ -101,6 +112,7 @@ def _extract_info(self, ax: Axes) -> None:
101112                continue 
102113
103114            if  (
115+                 # Check if the line is inside the box, prevent floating point errors 
104116                ydata [0 ] ==  ydata [1 ]
105117                and  box ["y" ] <=  ydata [0 ] <=  box ["y" ] +  box ["height" ]
106118            ):
@@ -122,6 +134,7 @@ def _extract_info(self, ax: Axes) -> None:
122134                median = box ["median" ],
123135                third_quartile = box ["y" ] +  box ["height" ],
124136                max = box ["whisker_upper" ],
137+                 outliers = box ["outliers" ],
125138            )
126139            for  box  in  boxes 
127140        ]
0 commit comments