|
27 | 27 | ) |
28 | 28 | from balance.typing import FilePathOrBuffer |
29 | 29 | from balance.util import find_items_index_in_list, get_items_from_list_via_indices |
| 30 | +from balance.utils.input_validation import _verify_value_type |
30 | 31 | from IPython.lib.display import FileLink |
31 | 32 | from plotly.graph_objs import Figure |
32 | 33 |
|
@@ -392,18 +393,19 @@ def _call_on_linked( |
392 | 393 | # if v is not None and k not in exclude |
393 | 394 | # ) |
394 | 395 |
|
395 | | - # TODO: add the ability to pass formula argument to model_matrix |
396 | | - # but in which case - notice that we'd want the ability to track |
397 | | - # which object is stored in _model_matrix (and to run it over) |
398 | | - # Also, the output may sometimes no longer only be pd.DataFrame |
399 | | - # so such work will require update the type hinting here. |
400 | | - def model_matrix(self: "BalanceDF") -> pd.DataFrame: |
| 396 | + def model_matrix( |
| 397 | + self: "BalanceDF", formula: str | list[str] | None = None |
| 398 | + ) -> pd.DataFrame: |
401 | 399 | """Return a model_matrix version of the df inside the BalanceDF object using balance_util.model_matrix |
402 | 400 |
|
403 | 401 | This can be used to turn all character columns into one-hot encoded columns. |
404 | 402 |
|
405 | 403 | Args: |
406 | 404 | self (BalanceDF): Object |
| 405 | + formula (str | list[str] | None, optional): Optional formula string (or list of |
| 406 | + formula strings) to pass to :func:`balance_util.model_matrix`. When |
| 407 | + provided, the model matrix is computed on demand for the formula and |
| 408 | + not cached on the object. Defaults to None. |
407 | 409 |
|
408 | 410 | Returns: |
409 | 411 | pd.DataFrame: The output from :func:`balance_util.model_matrix` |
@@ -443,12 +445,27 @@ def model_matrix(self: "BalanceDF") -> pd.DataFrame: |
443 | 445 | # 1 2.0 8.0 0.0 0.0 1.0 0.0 |
444 | 446 | # 2 3.0 2.0 0.0 0.0 0.0 1.0 |
445 | 447 | # 3 1.0 -42.0 1.0 0.0 0.0 0.0 |
| 448 | +
|
| 449 | + print(s1.covars().model_matrix(formula="a + b")) |
| 450 | + # a b |
| 451 | + # 0 1.0 -42.0 |
| 452 | + # 1 2.0 8.0 |
| 453 | + # 2 3.0 2.0 |
| 454 | + # 3 1.0 -42.0 |
446 | 455 | """ |
447 | | - if not hasattr(self, "_model_matrix") or self._model_matrix is None: |
448 | | - self._model_matrix = balance_util.model_matrix( |
449 | | - self.df, add_na=True, return_type="one" |
450 | | - )["model_matrix"] |
451 | | - return self._model_matrix |
| 456 | + if formula is None: |
| 457 | + if not hasattr(self, "_model_matrix") or self._model_matrix is None: |
| 458 | + self._model_matrix = balance_util.model_matrix( |
| 459 | + self.df, add_na=True, return_type="one" |
| 460 | + )["model_matrix"] |
| 461 | + return self._model_matrix |
| 462 | + |
| 463 | + return _verify_value_type( |
| 464 | + balance_util.model_matrix( |
| 465 | + self.df, add_na=True, return_type="one", formula=formula |
| 466 | + )["model_matrix"], |
| 467 | + pd.DataFrame, |
| 468 | + ) |
452 | 469 |
|
453 | 470 | def _descriptive_stats( |
454 | 471 | self: "BalanceDF", |
|
0 commit comments