Commit 6cb8a56

neuralsorcerer authored and meta-codesync[bot] committed
Raw-covariate adjustment for custom models (#323)

Summary:
Pull Request resolved: #323

`Sample.adjust()` now supports fitting models on raw covariates (without a
model matrix) for IPW via `use_model_matrix=False`. Categorical columns are
encoded as integer codes (ordinal encoding). NaN values in categorical columns
are assigned a distinct code (one higher than the maximum) rather than being
mapped to `-1`.

Note that ordinal encoding treats categories as ordered numeric values; for
true unordered categorical support, use sklearn 1.4+ with
`HistGradientBoostingClassifier` and `categorical_features="from_dtype"`.

Pull Request resolved: #321

Reviewed By: omriharosh

Differential Revision: D92627587

Pulled By: talgalili

fbshipit-source-id: 94d3b7d9e33bdd0b384dd18019e1ef653945d17f
1 parent d54dcc0 commit 6cb8a56
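The NaN handling described in the summary can be sketched as follows. This is a hypothetical illustration of the encoding scheme, not the function the commit adds; `ordinal_encode` is an invented name:

```python
import pandas as pd

def ordinal_encode(s: pd.Series) -> pd.Series:
    """Sketch of the encoding described above: categories become integer
    codes, and NaN gets a distinct code one higher than the maximum code
    instead of pandas' default -1."""
    codes = pd.Series(pd.Categorical(s).codes, index=s.index)
    # pandas marks missing values with -1; remap them to max code + 1
    return codes.mask(codes == -1, codes.max() + 1)

s = pd.Series(["a", "b", None, "a"])
print(ordinal_encode(s).tolist())  # [0, 1, 2, 0] -- NaN mapped to 2
```

Keeping NaN as its own (largest) code lets tree-based models split missing values into their own branch rather than conflating them with the lowest category.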

File tree

8 files changed: +557 −98 lines changed

.github/workflows/deploy-website.yml

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: 3.9
+          python-version: "3.12"
       - name: Install Pkg + Dependencies
         run: |
           python -m pip install .[dev]

CHANGELOG.md

Lines changed: 17 additions & 8 deletions

@@ -2,23 +2,32 @@
 
 ## New Features
 
-- **Validate weights include positive values**
-  - Added a guard in weight diagnostics to error when all weights are zero.
-- **Support configurable ID column candidates**
-  - `Sample.from_frame()` and `guess_id_column()` now accept candidate ID column names
-    when auto-detecting the ID column.
 - **Outcome weight impact diagnostics**
   - Added paired outcome-weight impact tests (`y*w0` vs `y*w1`) with confidence intervals.
   - Exposed in `BalanceDFOutcomes`, `Sample.diagnostics()`, and the CLI via
     `--weights_impact_on_outcome_method`.
 - **Pandas 3 support**
   - Updated compatibility and tests for pandas 3.x
-- **Formula support for BalanceDF model matrices**
-  - `BalanceDF.model_matrix()` now accepts a `formula` argument to build
-    custom model matrices without precomputing them manually.
 - **Categorical distribution metrics without one-hot encoding**
   - KLD/EMD/CVMD/KS on `BalanceDF.covars()` now operate on raw categorical variables
     (with NA indicators) instead of one-hot encoded columns.
+- **Misc**
+- **Raw-covariate adjustment for custom models**
+  - `Sample.adjust()` now supports fitting models on raw covariates (without a model matrix)
+    for IPW via `use_model_matrix=False`. String, object, and boolean columns are converted
+    to pandas `Categorical` dtype, allowing sklearn estimators with native categorical
+    support (e.g., `HistGradientBoostingClassifier` with `categorical_features="from_dtype"`)
+    to handle them correctly. Requires scikit-learn >= 1.4 when categorical columns are
+    present.
+- **Validate weights include positive values**
+  - Added a guard in weight diagnostics to error when all weights are zero.
+- **Support configurable ID column candidates**
+  - `Sample.from_frame()` and `guess_id_column()` now accept candidate ID column names
+    when auto-detecting the ID column.
+- **Formula support for BalanceDF model matrices**
+  - `BalanceDF.model_matrix()` now accepts a `formula` argument to build
+    custom model matrices without precomputing them manually.
+
 
 ## Bug Fixes
 
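The changelog entry says string, object, and boolean columns are converted to pandas `Categorical` dtype before model fitting. A minimal sketch of such a conversion; `coerce_to_categorical` is a hypothetical helper, not balance's actual implementation:

```python
import pandas as pd

def coerce_to_categorical(df: pd.DataFrame) -> pd.DataFrame:
    """Sketch of the conversion described above: string, object, and
    boolean columns become pandas Categorical so estimators with native
    categorical support can detect them from the dtype."""
    out = df.copy()
    for col in out.columns:
        if (
            out[col].dtype == object
            or pd.api.types.is_bool_dtype(out[col])
            or pd.api.types.is_string_dtype(out[col])
        ):
            out[col] = out[col].astype("category")
    return out

df = pd.DataFrame(
    {"g": ["a", "b", "a"], "flag": [True, False, True], "x": [1.0, 2.0, 3.0]}
)
print(coerce_to_categorical(df).dtypes.astype(str).tolist())
# ['category', 'category', 'float64']
```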

balance/sample_class.py

Lines changed: 27 additions & 21 deletions

@@ -510,10 +510,9 @@ def from_frame(
         ]
         # TODO:(after 2026) that if pandas >=3, this doesn't cause issues for users importing data from SQL
         # In pandas < 3, convert string dtype to object for compatibility
-        _pd_version = tuple(
-            int(x) for x in importlib_version("pandas").split(".")[:2]
-        )
-        if _pd_version < (3, 0):
+        from packaging.version import Version
+
+        if Version(importlib_version("pandas")) < Version("3.0"):
             input_type.append("string")
             output_type.append("object")
         for i_input, i_output in zip(input_type, output_type):
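The motivation for switching to `packaging.version.Version`: naive integer-tuple parsing mishandles pre-release version strings, while `Version` compares them per PEP 440. A small illustration (assumes the `packaging` package is installed, which pip itself depends on):

```python
from packaging.version import Version

# A pre-release of pandas 3.0 should still take the "< 3.0" compatibility
# branch; Version gets this right because rc1 sorts before the release.
print(Version("3.0.0rc1") < Version("3.0"))  # True

# The tuple approach parses "3.0.0rc1" to (3, 0), and (3, 0) < (3, 0)
# is False, so a 3.0 release candidate would wrongly skip the branch.
print(tuple(int(x) for x in "3.0.0rc1".split(".")[:2]) < (3, 0))  # False
```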
@@ -940,31 +939,36 @@ def adjust(
         .. code-block:: python
 
             import balance
-            from sklearn.ensemble import RandomForestClassifier
+            from sklearn.ensemble import HistGradientBoostingClassifier
             from balance import Sample
             from balance import load_data
 
             # Load simulated data
             target_df, sample_df = load_data()
 
             sample = Sample.from_frame(sample_df, outcome_columns=["happiness"])
-            # Often times we don'y have the outcome for the target. In this case we've added it just to validate later that the weights indeed help us reduce the bias
+            # Oftentimes we don't have the outcome for the target. In this case we've added it just to validate later that the weights indeed help us reduce the bias
             target = Sample.from_frame(target_df, outcome_columns=["happiness"])
 
             sample_with_target = sample.set_target(target)
             adjusted = sample_with_target.adjust()
 
-            rf = RandomForestClassifier(n_estimators=200, random_state=0)
-            adjusted_rf = sample_with_target.adjust(model = rf)
+            hgb = HistGradientBoostingClassifier(
+                random_state=0, categorical_features="from_dtype"
+            )
+            adjusted_hgb = sample_with_target.adjust(
+                model=hgb,
+                use_model_matrix=False,
+            )
 
-            # Print ASMD tables for both adjusted and adjusted_rf
+            # Print ASMD tables for both adjusted and adjusted_hgb
             print("\\n=== Adjusted ASMD ===")
             print(adjusted.covars().asmd().T)
 
-            print("\\n=== Adjusted_RF ASMD ===")
-            print(adjusted_rf.covars().asmd().T)
+            print("\\n=== Adjusted_HGB ASMD ===")
+            print(adjusted_hgb.covars().asmd().T)
 
-            # output
+            # output (values will vary by model and random seed)
             #
             # === Adjusted ASMD ===
             # source                  self  unadjusted  unadjusted - self
@@ -977,16 +981,16 @@ def adjust(
             # income              0.205469    0.494217           0.288748
             # mean(asmd)          0.119597    0.326799           0.207202
             #
-            # === Adjusted_RF ASMD ===
+            # === Adjusted_HGB ASMD ===
             # source                  self  unadjusted  unadjusted - self
-            # age_group[T.25-34]  0.074491    0.005688          -0.068804
-            # age_group[T.35-44]  0.022383    0.312711           0.290328
-            # age_group[T.45+]    0.145628    0.378828           0.233201
-            # gender[Female]      0.037700    0.375699           0.337999
-            # gender[Male]        0.067392    0.379314           0.311922
-            # gender[_NA]         0.051718    0.006296          -0.045422
-            # income              0.140655    0.494217           0.353562
-            # mean(asmd)          0.091253    0.326799           0.235546
+            # age_group[T.25-34]       ...    0.005688                ...
+            # age_group[T.35-44]       ...    0.312711                ...
+            # age_group[T.45+]         ...    0.378828                ...
+            # gender[Female]           ...    0.375699                ...
+            # gender[Male]             ...    0.379314                ...
+            # gender[_NA]              ...    0.006296                ...
+            # income                   ...    0.494217                ...
+            # mean(asmd)               ...    0.326799                ...
             """
         if target is None:
             self._no_target_error()

balance/utils/data_transformation.py

Lines changed: 23 additions & 0 deletions

@@ -91,6 +91,29 @@ def add_na_indicator_to_combined(df: pd.DataFrame) -> pd.DataFrame:
     Returns:
         pd.DataFrame: The DataFrame with NA indicator columns added for every
             base column that contains missing values.
+
+    Examples:
+        Basic usage on a DataFrame without pre-existing indicators:
+
+        >>> import pandas as pd
+        >>> from balance.utils.data_transformation import add_na_indicator_to_combined
+        >>> df = pd.DataFrame({"x": [1.0, None, 3.0], "y": [0, 1, 2]})
+        >>> result = add_na_indicator_to_combined(df)
+        >>> result.columns.tolist()
+        ['x', 'y', '_is_na_x']
+
+        When the input already contains ``_is_na_*`` columns, they are preserved
+        and not duplicated:
+
+        >>> df2 = pd.DataFrame(
+        ...     {
+        ...         "x": [1.0, None, 3.0],
+        ...         "_is_na_y": [0, 1, 0],
+        ...     }
+        ... )
+        >>> result2 = add_na_indicator_to_combined(df2)
+        >>> result2.columns.tolist()
+        ['x', '_is_na_x', '_is_na_y']
     """
     existing_indicator_cols = [
         col for col in df.columns if isinstance(col, str) and col.startswith("_is_na_")
     ]
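The behavior documented in the doctests above can be sketched independently. `add_na_indicators` is a hypothetical stand-in, not the library function, and may differ from it in column ordering:

```python
import pandas as pd

def add_na_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Sketch of the documented behavior: add an _is_na_<col> indicator
    for each base column containing missing values, skipping columns
    that already have an indicator."""
    out = df.copy()
    existing = {
        c for c in out.columns if isinstance(c, str) and c.startswith("_is_na_")
    }
    for col in df.columns:
        if isinstance(col, str) and col.startswith("_is_na_"):
            continue  # never build an indicator for an indicator
        name = f"_is_na_{col}"
        if name not in existing and df[col].isna().any():
            out[name] = df[col].isna().astype(int)
    return out

df = pd.DataFrame({"x": [1.0, None, 3.0], "y": [0, 1, 2]})
print(add_na_indicators(df).columns.tolist())  # ['x', 'y', '_is_na_x']
```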

0 commit comments