Skip to content

Commit cb99f12

Browse files
neuralsorcerer authored and meta-codesync[bot] committed
Add formula support to BalanceDF model_matrix (#318)
Summary:
- Closes #304

Pull Request resolved: #318

Differential Revision: D92394155

Pulled By: talgalili

fbshipit-source-id: ec76b8f628e555c15b568f72b5cea9fca00631ea
1 parent 90a3b12 commit cb99f12

File tree

3 files changed

+75
-11
lines changed

3 files changed

+75
-11
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
`--weights_impact_on_outcome_method`.
1414
- **Pandas 3 support**
1515
- Updated compatibility and tests for pandas 3.x
16+
- **Formula support for BalanceDF model matrices**
17+
- `BalanceDF.model_matrix()` now accepts a `formula` argument to build
18+
custom model matrices without precomputing them manually.
1619

1720
## Bug Fixes
1821

balance/balancedf_class.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from balance.typing import FilePathOrBuffer
2929
from balance.util import find_items_index_in_list, get_items_from_list_via_indices
30+
from balance.utils.input_validation import _verify_value_type
3031
from IPython.lib.display import FileLink
3132
from plotly.graph_objs import Figure
3233

@@ -392,18 +393,19 @@ def _call_on_linked(
392393
# if v is not None and k not in exclude
393394
# )
394395

395-
# TODO: add the ability to pass formula argument to model_matrix
396-
# but in which case - notice that we'd want the ability to track
397-
# which object is stored in _model_matrix (and to run it over)
398-
# Also, the output may sometimes no longer only be pd.DataFrame
399-
# so such work will require update the type hinting here.
400-
def model_matrix(self: "BalanceDF") -> pd.DataFrame:
396+
def model_matrix(
397+
self: "BalanceDF", formula: str | list[str] | None = None
398+
) -> pd.DataFrame:
401399
"""Return a model_matrix version of the df inside the BalanceDF object using balance_util.model_matrix
402400
403401
This can be used to turn all character columns into a one hot encoding columns.
404402
405403
Args:
406404
self (BalanceDF): Object
405+
formula (str | list[str] | None, optional): Optional formula string (or list of
406+
formula strings) to pass to :func:`balance_util.model_matrix`. When
407+
provided, the model matrix is computed on demand for the formula and
408+
not cached on the object. Defaults to None.
407409
408410
Returns:
409411
pd.DataFrame: The output from :func:`balance_util.model_matrix`
@@ -443,12 +445,27 @@ def model_matrix(self: "BalanceDF") -> pd.DataFrame:
443445
# 1 2.0 8.0 0.0 0.0 1.0 0.0
444446
# 2 3.0 2.0 0.0 0.0 0.0 1.0
445447
# 3 1.0 -42.0 1.0 0.0 0.0 0.0
448+
449+
print(s1.covars().model_matrix(formula="a + b"))
450+
# a b
451+
# 0 1.0 -42.0
452+
# 1 2.0 8.0
453+
# 2 3.0 2.0
454+
# 3 1.0 -42.0
446455
"""
447-
if not hasattr(self, "_model_matrix") or self._model_matrix is None:
448-
self._model_matrix = balance_util.model_matrix(
449-
self.df, add_na=True, return_type="one"
450-
)["model_matrix"]
451-
return self._model_matrix
456+
if formula is None:
457+
if not hasattr(self, "_model_matrix") or self._model_matrix is None:
458+
self._model_matrix = balance_util.model_matrix(
459+
self.df, add_na=True, return_type="one"
460+
)["model_matrix"]
461+
return self._model_matrix
462+
463+
return _verify_value_type(
464+
balance_util.model_matrix(
465+
self.df, add_na=True, return_type="one", formula=formula
466+
)["model_matrix"],
467+
pd.DataFrame,
468+
)
452469

453470
def _descriptive_stats(
454471
self: "BalanceDF",

tests/test_balancedf.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from balance.sample_class import Sample
2424
from balance.stats_and_plots import weighted_comparisons_stats
2525
from balance.testutil import BalanceTestCase, tempfile_path
26+
from balance.utils.model_matrix import model_matrix
27+
from patsy import PatsyError # pyre-ignore[21]
2628

2729

2830
class TestDataFactory:
@@ -1721,6 +1723,48 @@ def testBalanceDF_model_matrix(self) -> None:
17211723
},
17221724
)
17231725

1726+
def testBalanceDF_model_matrix_with_formula(self) -> None:
1727+
covars = s1.covars()
1728+
expected = model_matrix(
1729+
covars.df, add_na=True, return_type="one", formula="a + b"
1730+
)["model_matrix"]
1731+
result = covars.model_matrix(formula="a + b")
1732+
pd.testing.assert_frame_equal(result, expected)
1733+
1734+
def testBalanceDF_model_matrix_with_formula_list(self) -> None:
1735+
covars = s1.covars()
1736+
expected = model_matrix(
1737+
covars.df, add_na=True, return_type="one", formula=["a", "b"]
1738+
)["model_matrix"]
1739+
result = covars.model_matrix(formula=["a", "b"])
1740+
pd.testing.assert_frame_equal(result, expected)
1741+
1742+
def testBalanceDF_model_matrix_with_interaction_formula(self) -> None:
1743+
covars = s1.covars()
1744+
expected = model_matrix(
1745+
covars.df, add_na=True, return_type="one", formula="a * c"
1746+
)["model_matrix"]
1747+
result = covars.model_matrix(formula="a * c")
1748+
pd.testing.assert_frame_equal(result, expected)
1749+
1750+
def testBalanceDF_model_matrix_formula_does_not_affect_cache(self) -> None:
1751+
covars = s1.covars()
1752+
cached_before = covars.model_matrix()
1753+
formula_result = covars.model_matrix(formula="a")
1754+
cached_after = covars.model_matrix()
1755+
pd.testing.assert_frame_equal(
1756+
formula_result,
1757+
model_matrix(covars.df, add_na=True, return_type="one", formula="a")[
1758+
"model_matrix"
1759+
],
1760+
)
1761+
pd.testing.assert_frame_equal(cached_after, cached_before)
1762+
1763+
def testBalanceDF_model_matrix_with_invalid_formula(self) -> None:
1764+
covars = s1.covars()
1765+
with self.assertRaises(PatsyError):
1766+
covars.model_matrix(formula="missing_column + a")
1767+
17241768
def test_check_if_not_BalanceDF(self) -> None:
17251769
with self.assertRaisesRegex(ValueError, "number must be balancedf_class"):
17261770
BalanceDF._check_if_not_BalanceDF(

0 commit comments

Comments
 (0)