@@ -65,11 +65,12 @@ def _truncate_top_k_categories(
6565 {v : ellide_string (v , max_len = 20 ) for v in values if isinstance (v , str )}
6666 )
6767 else :
68+ original_dtype = col .dtype
6869 col [~ keep ] = other_label
6970 col = col .apply (
7071 lambda x : ellide_string (x , max_len = 20 ) if isinstance (x , str ) else x
7172 )
72- col = col .astype (object )
73+ col = col .astype (original_dtype )
7374 return col
7475
7576
@@ -290,7 +291,7 @@ def _plot_matplotlib(
290291 hue : str | None = None ,
291292 kind : Literal ["dist" , "corr" ] = "dist" ,
292293 top_k_categories : int = 20 ,
293- ) -> None :
294+ ) -> tuple [ Figure , Axes ] :
294295 """Matplotlib implementation of the `plot` method."""
295296 self .figure_ , self .ax_ = plt .subplots ()
296297 if kind == "dist" :
@@ -306,6 +307,8 @@ def _plot_matplotlib(
306307 y = y ,
307308 k = top_k_categories ,
308309 histplot_kwargs = self ._default_histplot_kwargs ,
310+ figure = self .figure_ ,
311+ ax = self .ax_ ,
309312 )
310313 case _:
311314 self ._plot_distribution_2d (
@@ -317,6 +320,8 @@ def _plot_matplotlib(
317320 stripplot_kwargs = self ._default_stripplot_kwargs ,
318321 boxplot_kwargs = self ._default_boxplot_kwargs ,
319322 heatmap_kwargs = self ._default_heatmap_kwargs ,
323+ figure = self .figure_ ,
324+ ax = self .ax_ ,
320325 )
321326
322327 elif kind == "corr" :
@@ -327,18 +332,22 @@ def _plot_matplotlib(
327332 raise ValueError (
328333 f"When { kind = !r} , { param_name !r} argument must be None."
329334 )
330- self ._plot_cramer (heatmap_kwargs = self ._default_heatmap_kwargs )
335+ self ._plot_cramer (heatmap_kwargs = self ._default_heatmap_kwargs , ax = self . ax_ )
331336
332337 else :
333338 raise ValueError (f"'kind' options are 'dist', 'corr', got { kind !r} ." )
334339
340+ return (self .figure_ , self .ax_ )
341+
335342 def _plot_distribution_1d (
336343 self ,
337344 * ,
338345 x : str | None ,
339346 y : str | None ,
340347 k : int ,
341348 histplot_kwargs : dict [str , Any ],
349+ figure : Figure ,
350+ ax : Axes ,
342351 ) -> None :
343352 """Plot 1-dimensional distribution of a feature.
344353
@@ -387,16 +396,16 @@ def _plot_distribution_1d(
387396 histplot_params = {"x" : column }
388397 despine_params = {"bottom" : is_categorical }
389398 if duration_unit is not None :
390- self . ax_ .set (xlabel = f"{ duration_unit .capitalize ()} s" )
399+ ax .set (xlabel = f"{ duration_unit .capitalize ()} s" )
391400 else : # y is not None
392401 histplot_params = {"y" : column }
393402 despine_params = {"left" : is_categorical }
394403 if duration_unit is not None :
395- self . ax_ .set (ylabel = f"{ duration_unit .capitalize ()} s" )
404+ ax .set (ylabel = f"{ duration_unit .capitalize ()} s" )
396405
397- sns .histplot (ax = self . ax_ , ** histplot_params , ** histplot_kwargs_validated )
406+ sns .histplot (ax = ax , ** histplot_params , ** histplot_kwargs_validated )
398407 sns .despine (
399- self . figure_ ,
408+ figure ,
400409 top = True ,
401410 right = True ,
402411 trim = True ,
@@ -406,17 +415,17 @@ def _plot_distribution_1d(
406415
407416 if is_categorical :
408417 _resize_categorical_axis (
409- figure = self . figure_ ,
410- ax = self . ax_ ,
418+ figure = figure ,
419+ ax = ax ,
411420 n_categories = sbd .n_unique (column ),
412421 is_x_axis = x is not None ,
413422 )
414423
415424 if x is not None and any (
416- len (label .get_text ()) > 1 for label in self . ax_ .get_xticklabels ()
425+ len (label .get_text ()) > 1 for label in ax .get_xticklabels ()
417426 ):
418427 # rotate only for string longer than 1 character
419- _rotate_ticklabels (self . ax_ , rotation = 45 )
428+ _rotate_ticklabels (ax , rotation = 45 )
420429
421430 def _plot_distribution_2d (
422431 self ,
@@ -429,6 +438,8 @@ def _plot_distribution_2d(
429438 scatterplot_kwargs : dict [str , Any ],
430439 hue : str | None = None ,
431440 k : int = 20 ,
441+ figure : Figure ,
442+ ax : Axes ,
432443 ) -> None :
433444 """Plot 2-dimensional distribution of two features.
434445
@@ -478,7 +489,7 @@ def _plot_distribution_2d(
478489 x = x ,
479490 y = y ,
480491 hue = hue ,
481- ax = self . ax_ ,
492+ ax = ax ,
482493 ** scatterplot_kwargs_validated ,
483494 )
484495 elif is_x_num or is_y_num :
@@ -512,23 +523,21 @@ def _plot_distribution_2d(
512523 else :
513524 x = _truncate_top_k_categories (x , k )
514525
515- sns .boxplot (x = x , y = y , ax = self . ax_ , ** boxplot_kwargs_validated )
516- sns .stripplot (x = x , y = y , hue = hue , ax = self . ax_ , ** stripplot_kwargs_validated )
526+ sns .boxplot (x = x , y = y , ax = ax , ** boxplot_kwargs_validated )
527+ sns .stripplot (x = x , y = y , hue = hue , ax = ax , ** stripplot_kwargs_validated )
517528
518529 _resize_categorical_axis (
519- figure = self . figure_ ,
520- ax = self . ax_ ,
530+ figure = figure ,
531+ ax = ax ,
521532 n_categories = sbd .n_unique (y ) if is_x_num else sbd .n_unique (x ),
522533 is_x_axis = not is_x_num ,
523534 )
524535 if is_x_num :
525536 despine_params ["left" ] = True
526537 else :
527538 despine_params ["bottom" ] = True
528- if any (
529- len (label .get_text ()) > 1 for label in self .ax_ .get_xticklabels ()
530- ):
531- _rotate_ticklabels (self .ax_ , rotation = 45 )
539+ if any (len (label .get_text ()) > 1 for label in ax .get_xticklabels ()):
540+ _rotate_ticklabels (ax , rotation = 45 )
532541 else :
533542 if (hue is not None ) and (not sbd .is_numeric (hue )):
534543 raise ValueError (
@@ -576,9 +585,9 @@ def _plot_distribution_2d(
576585 },
577586 heatmap_kwargs ,
578587 )
579- sns .heatmap (contingency_table , ax = self . ax_ , ** heatmap_kwargs_validated )
588+ sns .heatmap (contingency_table , ax = ax , ** heatmap_kwargs_validated )
580589 despine_params .update (left = True , bottom = True )
581- self . ax_ .tick_params (axis = "both" , length = 0 )
590+ ax .tick_params (axis = "both" , length = 0 )
582591
583592 for is_x_axis , x_or_y in zip (
584593 [True , False ],
@@ -589,26 +598,28 @@ def _plot_distribution_2d(
589598 strict = False ,
590599 ):
591600 _resize_categorical_axis (
592- figure = self . figure_ ,
593- ax = self . ax_ ,
601+ figure = figure ,
602+ ax = ax ,
594603 n_categories = sbd .n_unique (x_or_y ),
595604 is_x_axis = is_x_axis ,
596605 size_per_category = size_per_category ,
597606 )
598607
599- sns .despine (self . figure_ , ** despine_params )
608+ sns .despine (figure , ** despine_params )
600609
601- self . ax_ .set (xlabel = sbd .name (x ), ylabel = sbd .name (y ))
602- if self . ax_ .legend_ is not None :
603- sns .move_legend (self . ax_ , (1.05 , 0.0 ))
610+ ax .set (xlabel = sbd .name (x ), ylabel = sbd .name (y ))
611+ if ax .legend_ is not None :
612+ sns .move_legend (ax , (1.05 , 0.0 ))
604613
605- def _plot_cramer (self , * , heatmap_kwargs : dict [str , Any ]) -> None :
614+ def _plot_cramer (self , * , heatmap_kwargs : dict [str , Any ], ax : Axes ) -> None :
606615 """Plot Cramer's V correlation among all columns.
607616
608617 Parameters
609618 ----------
610619 heatmap_kwargs : dict, default=None
611620 Keyword arguments to be passed to heatmap.
621+ ax : Axes
622+ The axes to plot on.
612623 """
613624 heatmap_kwargs_validated = _validate_style_kwargs (
614625 {
@@ -632,8 +643,8 @@ def _plot_cramer(self, *, heatmap_kwargs: dict[str, Any]) -> None:
632643 # and keep the diagonal as well.
633644 mask = np .triu (np .ones_like (cramer_v_table , dtype = bool ), k = 1 )
634645
635- sns .heatmap (cramer_v_table , mask = mask , ax = self . ax_ , ** heatmap_kwargs_validated )
636- self . ax_ .set (title = "Cramer's V Correlation" )
646+ sns .heatmap (cramer_v_table , mask = mask , ax = ax , ** heatmap_kwargs_validated )
647+ ax .set (title = "Cramer's V Correlation" )
637648
638649 def frame (
639650 self , * , kind : Literal ["dataset" , "top-associations" ] = "dataset"
0 commit comments