Skip to content

Commit a7e3fd4

Browse files
sahil350 authored and meta-codesync[bot] committed
Refactor _auto_bar_width and add _auto_bar_width_columnar (#348)
Summary: Refactored the `_auto_bar_width` function to simplify its interface by removing the unused `n_datasets` parameter, and extracted the columnar layout logic into a new dedicated `_auto_bar_width_columnar` function. The original `_auto_bar_width` was trying to serve two different purposes, but the `n_datasets` parameter wasn't actually used in the calculation.

This change:
- Simplifies `_auto_bar_width` for single-bar-per-line layouts (grouped barplots and histograms)
- Adds `_auto_bar_width_columnar` specifically for side-by-side columnar layouts used by `ascii_comparative_hist`
- Improves code clarity by having separate functions with clear docstrings for each use case
- Removes the "Overriding n_bins and bar_width" example from balance_ascii_plots.ipynb

Differential Revision: D94465988
Pulled By: sahil350
fbshipit-source-id: fad39e4a7ebbf90dbc9e3c9d2df28672147c7145
1 parent bd1c5bc commit a7e3fd4

File tree

5 files changed

+240
-97
lines changed

5 files changed

+240
-97
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@
1717
and `]` for deficit relative to the baseline. Uses visually distinct fill
1818
characters (`█`, `▒`, `▐`, `░`) for better readability when comparing
1919
multiple datasets.
20+
- **ASCII plot display order: population → adjusted → sample**
21+
- Comparative ASCII plots now order datasets as population, adjusted, sample
22+
(instead of sample, adjusted, population) so the target distribution appears
23+
first.
24+
- **`comparative` parameter for ASCII numeric plots**
25+
- `ascii_plot_dist` (and `BalanceDF.plot(library="balance")`) accepts a new
26+
`comparative` keyword (default `True`). When `True`, numeric variables use
27+
the columnar comparative histogram (`ascii_comparative_hist`). Set
28+
`comparative=False` to use grouped-bar histograms (`ascii_plot_hist`) for
29+
numeric variables instead, matching the style used for categorical variables.
2030

2131
## Documentation
2232

balance/balancedf_class.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,11 @@ def plot(
626626
can be used for ASCII text output suitable for LLM consumption (only dist_type="hist_ascii" is supported
627627
with library="balance").
628628
629+
When using ``library="balance"``, numeric variables are rendered as
630+
comparative histograms by default (showing excess/deficit vs. a
631+
baseline). Pass ``comparative=False`` to use grouped-bar histograms
632+
instead (same style as categorical variables).
633+
629634
This function is inherited as is when invoking BalanceDFCovars.plot, but some modifications are made when
630635
preparing the data for BalanceDFOutcomes.plot and BalanceDFWeights.plot.
631636
@@ -686,6 +691,9 @@ def plot(
686691
687692
# ASCII text output (suitable for LLM consumption):
688693
s3_null.covars().plot(library = "balance", dist_type = "hist_ascii")
694+
695+
# ASCII with grouped-bar histograms instead of comparative:
696+
s3_null.covars().plot(library = "balance", comparative = False)
689697
"""
690698
if on_linked_samples:
691699
dfs_to_add = self._BalanceDF_child_from_linked_samples()

balance/stats_and_plots/ascii_plots.py

Lines changed: 152 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
import logging
11-
from typing import Dict, List, Optional
11+
from typing import Dict, List, Optional, Tuple
1212

1313
import numpy as np
1414
import numpy.typing as npt
@@ -27,6 +27,34 @@
2727
# Each dataset gets a unique character from this list.
2828
BAR_CHARS: List[str] = ["█", "▒", "▐", "░", "▄", "▀"]
2929

30+
# Preferred ordering for comparative plots: population first, then adjusted,
31+
# then sample. Known internal names are placed in this order; any unknown
32+
# names are appended at the end in their original order.
33+
_PREFERRED_NAME_ORDER: List[str] = ["target", "self", "adjusted", "unadjusted"]
34+
35+
36+
def _reorder_dfs_and_names(
    dfs: List[DataFrameWithWeight],
    names: List[str],
) -> Tuple[List[DataFrameWithWeight], List[str]]:
    """Return *dfs* and *names* rearranged into the preferred display order.

    Names listed in ``_PREFERRED_NAME_ORDER`` come first, in the order that
    list gives them. Every other name follows, keeping the relative order it
    had in the input. ``dfs`` is permuted in lockstep with ``names`` so each
    dataset stays paired with its name.

    Args:
        dfs: Datasets to reorder, positionally aligned with ``names``.
        names: Display name of each dataset.

    Returns:
        A ``(dfs, names)`` tuple of new lists in display order; the inputs
        are not modified.
    """
    rank = {known: pos for pos, known in enumerate(_PREFERRED_NAME_ORDER)}
    # Unknown names all share one rank past the end of the preferred list, so
    # the original index (second key component) decides their order.
    unknown_rank = len(_PREFERRED_NAME_ORDER)

    positions = sorted(
        range(len(names)),
        key=lambda idx: (rank.get(names[idx], unknown_rank), idx),
    )

    ordered_dfs = [dfs[pos] for pos in positions]
    ordered_names = [names[pos] for pos in positions]
    return ordered_dfs, ordered_names
57+
3058

3159
def _auto_n_bins(n_samples: int, n_unique: int) -> int:
3260
"""Pick a number of bins using Sturges' rule, capped at unique values."""
@@ -39,8 +67,12 @@ def _auto_n_bins(n_samples: int, n_unique: int) -> int:
3967
return max(2, min(sturges, n_unique, 50))
4068

4169

42-
def _auto_bar_width(label_width: int, n_datasets: int) -> int:
43-
"""Pick bar_width to fit within terminal width."""
70+
def _auto_bar_width(label_width: int) -> int:
71+
"""Pick bar_width to fit within terminal width.
72+
73+
Used by grouped barplots and histograms where each dataset gets its own
74+
line within a row (single bar per line).
75+
"""
4476
import shutil
4577

4678
term_width = shutil.get_terminal_size((80, 24)).columns
@@ -49,6 +81,24 @@ def _auto_bar_width(label_width: int, n_datasets: int) -> int:
4981
return max(10, available)
5082

5183

84+
def _auto_bar_width_columnar(range_width: int, n_columns: int) -> int:
    """Pick per-column bar_width for a columnar (side-by-side) layout.

    Used by :func:`ascii_comparative_hist` where all datasets are rendered as
    columns on the same line. Each column needs space for the bar, a
    percentage string (~6 chars), and inter-column separators (`` | ``, 3
    chars each).

    Args:
        range_width: Character width of the leading range-label column.
        n_columns: Number of dataset columns rendered side by side.

    Returns:
        The bar width, in characters, to use for every column. Never less
        than 10, including when ``n_columns`` is smaller than 1.
    """
    import shutil

    min_width = 10
    # With no columns there is nothing to divide the space between; return
    # the floor instead of raising ZeroDivisionError below.
    if n_columns < 1:
        return min_width

    term_width = shutil.get_terminal_size((80, 24)).columns
    # Layout: "Range | col1 | col2 | ..."
    # The label column consumes range_width plus 4 characters for the
    # separator and padding that follow it.
    available = term_width - range_width - 4
    # Each boundary between two dataset columns costs 3 characters (" | ").
    separators = (n_columns - 1) * 3
    # Split the remainder evenly, then reserve ~6 characters per column for
    # the percentage string printed next to the bar.
    per_col = (available - separators) // n_columns - 6
    return max(min_width, per_col)
100+
101+
52102
def _weighted_histogram(
53103
values: pd.Series,
54104
weights: Optional[pd.Series],
@@ -174,7 +224,8 @@ def ascii_plot_bar(
174224
names: Names for each DataFrame (e.g., ["self", "target"]).
175225
column: The categorical column name to plot.
176226
weighted: Whether to use weights. Defaults to True.
177-
bar_width: Maximum character width for bars. Defaults to 40.
227+
bar_width: Maximum character width for bars. Defaults to None,
228+
which auto-detects based on terminal width.
178229
dist_type: Accepted for compatibility but only "hist_ascii" is supported.
179230
A warning is logged if any other value is passed.
180231
separate_categories: If True, insert a blank line between categories
@@ -243,7 +294,7 @@ def ascii_plot_bar(
243294
label_width = max(label_width, 8) # minimum width for "Category"
244295

245296
if bar_width is None:
246-
bar_width = _auto_bar_width(label_width, len(legend_names))
297+
bar_width = _auto_bar_width(label_width)
247298

248299
# Build output
249300
lines: List[str] = []
@@ -304,8 +355,10 @@ def ascii_plot_hist(
304355
names: Names for each DataFrame (e.g., ["self", "target"]).
305356
column: The numeric column name to plot.
306357
weighted: Whether to use weights. Defaults to True.
307-
n_bins: Number of histogram bins. Defaults to 10.
308-
bar_width: Maximum character width for bars. Defaults to 40.
358+
n_bins: Number of histogram bins. Defaults to None, which
359+
auto-detects using Sturges' rule.
360+
bar_width: Maximum character width for bars. Defaults to None,
361+
which auto-detects based on terminal width.
309362
dist_type: Accepted for compatibility but only "hist_ascii" is supported.
310363
A warning is logged if any other value is passed.
311364
@@ -395,7 +448,7 @@ def ascii_plot_hist(
395448
label_width = max(label_width, 3) # minimum width for "Bin"
396449

397450
if bar_width is None:
398-
bar_width = _auto_bar_width(label_width, len(legend_names))
451+
bar_width = _auto_bar_width(label_width)
399452

400453
# Build output
401454
lines: List[str] = []
@@ -459,8 +512,10 @@ def ascii_comparative_hist(
459512
names: Names for each DataFrame (e.g., ["Target", "Sample"]).
460513
column: The numeric column name to plot.
461514
weighted: Whether to use weights. Defaults to True.
462-
n_bins: Number of histogram bins. Defaults to 10.
463-
bar_width: Maximum character width for bars. Defaults to 20.
515+
n_bins: Number of histogram bins. Defaults to None, which
516+
auto-detects using Sturges' rule.
517+
bar_width: Maximum character width for bars. Defaults to None,
518+
which auto-detects based on terminal width.
464519
465520
Returns:
466521
ASCII comparative histogram text.
@@ -470,6 +525,8 @@ def ascii_comparative_hist(
470525
471526
>>> print(ascii_comparative_hist(dfs, names=["Target", "Sample"],
472527
... column="income", n_bins=2, bar_width=20))
528+
=== income (numeric, comparative) ===
529+
<BLANKLINE>
473530
Range | Target (%) | Sample (%)
474531
---------------------------------------------------------------
475532
[10.00, 25.00) | █████████████ 50.0 | █████████████▒▒▒▒▒▒▒ 75.0
@@ -538,14 +595,7 @@ def ascii_comparative_hist(
538595
range_width = max(len(range_header), max(len(lbl) for lbl in bin_labels))
539596

540597
if bar_width is None:
541-
import shutil
542-
543-
term_width = shutil.get_terminal_size((80, 24)).columns
544-
n_cols = len(legend_names)
545-
# Each column needs: bar_width + pct string (~6) + spacing (3)
546-
available = term_width - range_width - 4 # " | " separator
547-
per_col = max(10, (available - (n_cols - 1) * 3) // n_cols - 6)
548-
bar_width = per_col
598+
bar_width = _auto_bar_width_columnar(range_width, len(legend_names))
549599

550600
# Baseline percentages (first dataset)
551601
baseline_pcts = hist_pcts[0]
@@ -597,6 +647,8 @@ def ascii_comparative_hist(
597647

598648
# Build output
599649
lines: List[str] = []
650+
lines.append(f"=== {column} (numeric, comparative) ===")
651+
lines.append("")
600652

601653
# Header row
602654
header_parts = [range_header.ljust(range_width)]
@@ -645,12 +697,27 @@ def ascii_plot_dist(
645697
bar_width: Optional[int] = None,
646698
dist_type: Optional[str] = None,
647699
separate_categories: bool = True,
700+
comparative: bool = True,
648701
) -> str:
649702
"""Produces ASCII text comparing weighted distributions across datasets.
650703
651704
Iterates over variables, classifying each as categorical or numeric
652705
(using the same logic as :func:`seaborn_plot_dist`), then delegates to
653-
:func:`ascii_plot_bar` or :func:`ascii_plot_hist` respectively.
706+
the appropriate plotting function.
707+
708+
Two display modes are available for numeric variables:
709+
710+
- **comparative** (``comparative=True``, the default): numeric variables
711+
are rendered with :func:`ascii_comparative_hist`, a columnar layout
712+
where the first dataset is the baseline and subsequent datasets show
713+
excess / deficit relative to it.
714+
- **grouped** (``comparative=False``): numeric variables are rendered
715+
with :func:`ascii_plot_hist`, a grouped-bar layout where each dataset
716+
gets its own bar per bin (the same style used for categorical
717+
variables).
718+
719+
Categorical variables always use :func:`ascii_plot_bar` regardless of
720+
this setting.
654721
655722
The output is both printed to stdout and returned as a string.
656723
@@ -662,12 +729,19 @@ def ascii_plot_dist(
662729
numeric_n_values_threshold: Columns with fewer unique values than this
663730
are treated as categorical. Defaults to 15.
664731
weighted: Whether to use weights. Defaults to True.
665-
n_bins: Number of bins for numeric histograms. Defaults to 10.
666-
bar_width: Maximum character width for the longest bar. Defaults to 40.
732+
n_bins: Number of bins for numeric histograms. Defaults to None,
733+
which auto-detects using Sturges' rule.
734+
bar_width: Maximum character width for the longest bar. Defaults to
735+
None, which auto-detects based on terminal width.
667736
dist_type: Accepted for compatibility but only "hist_ascii" is supported.
668737
A warning is logged if any other value is passed.
669738
separate_categories: If True, insert a blank line between categories
670739
in barplots for readability. Defaults to True.
740+
comparative: If True (default), numeric variables use a columnar
741+
comparative histogram (:func:`ascii_comparative_hist`) that
742+
highlights differences relative to a baseline dataset. If
743+
False, numeric variables use a grouped-bar histogram
744+
(:func:`ascii_plot_hist`) instead.
671745
672746
Returns:
673747
The full ASCII output text.
@@ -693,31 +767,52 @@ def ascii_plot_dist(
693767
... numeric_n_values_threshold=0, n_bins=2, bar_width=20))
694768
=== color (categorical) ===
695769
<BLANKLINE>
696-
Category | sample population
770+
Category | population sample
697771
|
698-
blue | ████████████████████ (50.0%)
699-
| ▒▒▒▒▒▒▒▒▒▒ (25.0%)
772+
blue | ██████████ (25.0%)
773+
| ▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒ (50.0%)
700774
<BLANKLINE>
701775
green | ██████████ (25.0%)
702776
| ▒▒▒▒▒▒▒▒▒▒ (25.0%)
703777
<BLANKLINE>
704-
red | ██████████ (25.0%)
705-
| ▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒ (50.0%)
778+
red | ████████████████████ (50.0%)
779+
| ▒▒▒▒▒▒▒▒▒▒ (25.0%)
706780
<BLANKLINE>
707-
Legend: █ sample  ▒ population
781+
Legend: █ population  ▒ sample
708782
Bar lengths are proportional to weighted frequency within each dataset.
709783
<BLANKLINE>
784+
=== age (numeric, comparative) ===
785+
<BLANKLINE>
786+
Range | population (%) | sample (%)
787+
---------------------------------------------------------------
788+
[10.00, 25.00) | █████████████ 50.0 | █████████████▒▒▒▒▒▒▒ 75.0
789+
[25.00, 40.00] | █████████████ 50.0 | ███████ ] 25.0
790+
---------------------------------------------------------------
791+
Total | 100.0 | 100.0
792+
<BLANKLINE>
793+
Key: █ = shared with population, ▒ = excess, ] = deficit
794+
795+
To use grouped-bar histograms (same style as categorical) instead
796+
of comparative histograms for numeric variables, pass
797+
``comparative=False``::
798+
799+
>>> print(ascii_plot_dist(dfs, names=["self", "target"],
800+
... numeric_n_values_threshold=0, n_bins=2, bar_width=20,
801+
... comparative=False))
802+
=== color (categorical) ===
803+
...
710804
=== age (numeric) ===
711805
<BLANKLINE>
712-
Bin | sample population
806+
Bin | population sample
713807
|
714-
[10.00, 25.00) | █████████████ (50.0%)
715-
| ▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒ (75.0%)
716-
[25.00, 40.00] | █████████████ (50.0%)
717-
| ▒▒▒▒▒▒▒ (25.0%)
808+
[10.00, 25.00) | ████████████████████ (75.0%)
809+
| ▒▒▒▒▒▒▒▒▒▒▒▒▒ (50.0%)
810+
[25.00, 40.00] | ███████ (25.0%)
811+
| ▒▒▒▒▒▒▒▒▒▒▒▒▒ (50.0%)
718812
<BLANKLINE>
719-
Legend: █ sample  ▒ population
813+
Legend: █ population  ▒ sample
720814
Bar lengths are proportional to weighted frequency within each dataset.
815+
<BLANKLINE>
721816
"""
722817
if dist_type is not None and dist_type != "hist_ascii":
723818
logger.warning(
@@ -727,6 +822,9 @@ def ascii_plot_dist(
727822
if names is None:
728823
names = [f"df_{i}" for i in range(len(dfs))]
729824

825+
# Reorder so comparative plots show: population, adjusted, sample
826+
dfs, names = _reorder_dfs_and_names(dfs, names)
827+
730828
variables = choose_variables(*(d["df"] for d in dfs), variables=variables)
731829
logger.debug(f"ASCII plotting variables {variables}")
732830

@@ -757,16 +855,28 @@ def ascii_plot_dist(
757855
)
758856
)
759857
else:
760-
output_parts.append(
761-
ascii_plot_hist(
762-
dfs,
763-
names,
764-
o,
765-
weighted=weighted,
766-
n_bins=n_bins,
767-
bar_width=bar_width,
858+
if comparative:
859+
output_parts.append(
860+
ascii_comparative_hist(
861+
dfs,
862+
names,
863+
o,
864+
weighted=weighted,
865+
n_bins=n_bins,
866+
bar_width=bar_width,
867+
)
868+
)
869+
else:
870+
output_parts.append(
871+
ascii_plot_hist(
872+
dfs,
873+
names,
874+
o,
875+
weighted=weighted,
876+
n_bins=n_bins,
877+
bar_width=bar_width,
878+
)
768879
)
769-
)
770880

771881
result = "\n".join(output_parts)
772882
print(result)

0 commit comments

Comments
 (0)