Commit 722faab

talgalili authored and meta-codesync[bot] committed
Modernize type hints to use PEP 604 union syntax in balance package ( | ) (#157)
Summary:
Pull Request resolved: #157

Updated type annotations across the balance package to use the newer PEP 604 union syntax (`X | Y`) instead of the older `typing.Union` and `typing.Optional` syntax. This modernization improves code readability and aligns with Python 3.10+ typing conventions.

Key changes:
- Replaced `Union[X, Y]` with `X | Y`
- Replaced `Optional[X]` with `X | None`
- Updated `from __future__ import` statements to use `annotations` instead of the older `absolute_import, division, print_function, unicode_literals`
- Removed unnecessary `Union` and `Optional` imports from `typing`

All changes are backward compatible with Python 3.9 as `from __future__ import annotations` enables deferred evaluation of type hints, allowing the new syntax to work properly.

This modernization affects 11 files across the balance package including core modules like `adjustment.py`, `balancedf_class.py`, `sample_class.py`, and various stats and weighting methods.

Reviewed By: wesleytlee

Differential Revision: D87614454

fbshipit-source-id: 159b825ef95b419840365131efaa40c962f57bc6
1 parent 9f06259 commit 722faab

13 files changed: +261 -335 lines changed

CHANGELOG.md

Lines changed: 27 additions & 1 deletion
@@ -1,4 +1,4 @@
-# 0.12.x (2025-11-16)
+# 0.12.x (2025-11-21)

 > TODO: update 0.12.x to 0.13.0 before release.

@@ -65,6 +65,25 @@
 (`weighting_methods/cbps.py`, `weighting_methods/ipw.py`,
 `weighting_methods/poststratify.py`, `weighting_methods/rake.py`), and
 datasets module (`datasets/__init__.py`)
+- **Modernized type hints to PEP 604 syntax**: Updated all type annotations
+across 11 files to use the newer PEP 604 union syntax (`X | Y` instead of
+`Union[X, Y]` and `X | None` instead of `Optional[X]`), improving code
+readability and aligning with Python 3.10+ typing conventions. Updated
+`from __future__ import` statements to use `annotations` instead of the
+older `absolute_import, division, print_function, unicode_literals`.
+Removed unnecessary `Union` and `Optional` imports from `typing`. Files
+updated: `__init__.py`, `adjustment.py`, `balancedf_class.py`, `cli.py`,
+`datasets/__init__.py`, `sample_class.py`,
+`stats_and_plots/weighted_comparisons_stats.py`,
+`stats_and_plots/weighted_stats.py`, `stats_and_plots/weights_stats.py`,
+`util.py`, `weighting_methods/ipw.py`.
+- **Important compatibility note**:
+Type alias definitions in `typing.py` retain `Union` syntax for Python 3.9
+compatibility, as the `|` operator for type aliases only works at runtime
+in Python 3.10+. Added comprehensive inline documentation explaining this
+limitation and the distinction between type annotations (which support `|`
+with `from __future__ import annotations`) and type alias assignments
+(which require `Union` for runtime evaluation in Python 3.9).
 - Fixed missing `Any` import in `weighted_comparisons_plots.py` to resolve
 pyre-fixme[10] error
 - Added comprehensive type annotations for previously untyped parameters and
@@ -79,6 +98,13 @@
 - Improved `quantize` function: preserves column ordering and replaces
 assertions with proper TypeError exceptions
 ([#133](https://github.com/facebookresearch/balance/pull/133)).
+- **Statistical Functions**
+- **Fixed division by zero in `asmd_improvement()`**: Added safety check to
+prevent RuntimeWarning when `asmd_mean_before` is zero or very close to zero
+(< 1e-10). The function now returns `0.0` (representing 0% improvement) when
+the sample was already perfectly matched to the target before adjustment,
+which is the semantically correct result. This eliminates the "invalid value
+encountered in scalar divide" warning that appeared in test runs.
 - **Weighting Methods**
 - `rake()` and `poststratify()` now honour `weight_trimming_mean_ratio` and
 `weight_trimming_percentile`, trimming and renormalising weights through the
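
The division-by-zero guard described in the `asmd_improvement()` entry above can be sketched roughly as follows. This is an illustrative stand-in, not the code in `balancedf_class.py`: the function name `asmd_improvement_sketch`, its scalar arguments, and the relative-improvement formula `(before - after) / before` are assumptions made for the example. Only the `1e-10` threshold and the `0.0` return value come from the CHANGELOG text.

```python
def asmd_improvement_sketch(
    asmd_mean_before: float,
    asmd_mean_after: float,
    eps: float = 1e-10,
) -> float:
    """Relative reduction in mean ASMD, guarded against a zero baseline.

    Hypothetical helper: the formula below is an assumption for illustration.
    """
    # If the sample already matched the target before adjustment, there is
    # nothing to improve, so report 0% instead of dividing by (near) zero.
    if abs(asmd_mean_before) < eps:
        return 0.0
    return (asmd_mean_before - asmd_mean_after) / asmd_mean_before


print(asmd_improvement_sketch(0.5, 0.25))
print(asmd_improvement_sketch(0.0, 0.0))
```

Returning early keeps the division out of the zero-baseline path entirely, which is why the "invalid value encountered in scalar divide" warning no longer appears.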

balance/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -5,8 +5,9 @@

 # pyre-strict

+from __future__ import annotations
+
 import logging
-from typing import Optional

 from balance.balancedf_class import (  # noqa
     BalanceCovarsDF,  # noqa
@@ -41,7 +42,7 @@ def help() -> None:


 def setup_logging(
-    logger_name: Optional[str] = __package__,
+    logger_name: str | None = __package__,
     level: str = "INFO",
     removeHandler: bool = True,
 ) -> logging.Logger:

balance/adjustment.py

Lines changed: 16 additions & 18 deletions
@@ -5,11 +5,11 @@

 # pyre-strict

-from __future__ import absolute_import, division, print_function, unicode_literals
+from __future__ import annotations

 import logging

-from typing import Any, Callable, Dict, List, Literal, Tuple, Union
+from typing import Any, Callable, Dict, Literal, Tuple

 import numpy as np
 import numpy.typing as npt
@@ -38,9 +38,7 @@
 }


-def _validate_limit(
-    limit: Union[float, int, None], n_weights: int
-) -> Union[float, None]:
+def _validate_limit(limit: float | int | None, n_weights: int) -> float | None:
     """Validate and adjust a percentile limit for use with scipy.stats.mstats.winsorize.

     This function prepares percentile limits for winsorization by:
@@ -88,13 +86,13 @@ def _validate_limit(


 def trim_weights(
-    weights: Union[pd.Series, npt.NDArray],
+    weights: pd.Series | npt.NDArray,
     # TODO: add support to more types of input weights? (e.g. list? other?)
-    weight_trimming_mean_ratio: Union[float, int, None] = None,
-    weight_trimming_percentile: Union[float, None] = None,
+    weight_trimming_mean_ratio: float | int | None = None,
+    weight_trimming_percentile: float | None = None,
     verbose: bool = False,
     keep_sum_of_weights: bool = True,
-    target_sum_weights: Union[float, int, np.floating, None] = None,
+    target_sum_weights: float | int | np.floating | None = None,
 ) -> pd.Series:
     """Trim extreme weights using mean ratio clipping or percentile-based winsorization.

@@ -132,22 +130,22 @@ def trim_weights(
     desired total.

     Args:
-        weights (Union[pd.Series, np.ndarray]): Weights to trim. np.ndarray will be
+        weights (pd.Series | np.ndarray): Weights to trim. np.ndarray will be
             converted to pd.Series internally.
-        weight_trimming_mean_ratio (Union[float, int], optional): Ratio for upper bound
+        weight_trimming_mean_ratio (float | int | None, optional): Ratio for upper bound
             clipping as mean(weights) * ratio. Mutually exclusive with
             weight_trimming_percentile. Defaults to None.
-        weight_trimming_percentile (Union[float, Tuple[float, float]], optional):
+        weight_trimming_percentile (float | tuple[float, float] | None, optional):
             Percentile limits for winsorization. Value(s) must be between 0 and 1.
             - Single float: Symmetric winsorization on both tails
-            - Tuple[float, float]: (lower_percentile, upper_percentile) for
+            - tuple[float, float]: (lower_percentile, upper_percentile) for
               independent control of each tail
             Mutually exclusive with weight_trimming_mean_ratio. Defaults to None.
         verbose (bool, optional): Whether to log details about the trimming process.
             Defaults to False.
         keep_sum_of_weights (bool, optional): Whether to rescale weights after trimming
             to preserve the original sum of weights. Defaults to True.
-        target_sum_weights (Union[float, int, np.floating, None], optional): If
+        target_sum_weights (float | int | np.floating | None, optional): If
             provided, rescale the trimmed weights so their sum equals this
             target. ``None`` (default) leaves the post-trimming sum unchanged.

@@ -309,14 +307,14 @@ def trim_weights(


 def default_transformations(
-    dfs: Union[Tuple[pd.DataFrame, ...], List[pd.DataFrame]],
+    dfs: tuple[pd.DataFrame, ...] | list[pd.DataFrame],
 ) -> Dict[str, Callable[..., Any]]:
     """
     Apply default transformations to dfs, i.e.
     quantize to numeric columns and fct_lump to non-numeric and boolean

     Args:
-        dfs (Union[Tuple[pd.DataFrame, ...], List[pd.DataFrame]]): A list or tuple of dataframes
+        dfs (tuple[pd.DataFrame, ...] | list[pd.DataFrame]): A list or tuple of dataframes

     Returns:
         Dict[str, Callable]: Dict of transformations
@@ -339,7 +337,7 @@ def default_transformations(

 def apply_transformations(
     dfs: Tuple[pd.DataFrame, ...],
-    transformations: Union[Dict[str, Callable[..., Any]], str, None],
+    transformations: Dict[str, Callable[..., Any]] | str | None,
     drop: bool = True,
 ) -> Tuple[pd.DataFrame, ...]:
     """Apply the transformations specified in transformations to all of the dfs
@@ -357,7 +355,7 @@ def apply_transformations(

     Args:
         dfs (Tuple[pd.DataFrame, ...]): The DataFrames on which to operate
-        transformations (Union[Dict[str, Callable], str, None]): Mapping from column name to function to apply.
+        transformations (Dict[str, Callable[..., Any]] | str | None): Mapping from column name to function to apply.
             Transformations of existing columns should be specified as functions
             of those columns (e.g. `lambda x: x*2`), whereas additions of new
             columns should be specified as functions of the DataFrame
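
The `trim_weights` docstring above describes mean-ratio clipping as an upper bound of `mean(weights) * ratio`, optionally followed by rescaling to preserve the original sum. A minimal pure-Python sketch of that one branch is shown below. The helper `trim_by_mean_ratio` is hypothetical and works on plain lists. The real function operates on `pd.Series` and also handles percentile winsorization and `target_sum_weights`, none of which is reproduced here.

```python
from __future__ import annotations  # harmless if pasted mid-file is not allowed; keep at top

from statistics import mean


def trim_by_mean_ratio(
    weights: list[float],
    ratio: float,
    keep_sum_of_weights: bool = True,
) -> list[float]:
    """Clip weights at mean(weights) * ratio (hypothetical illustration)."""
    if not weights:
        return []
    original_sum = sum(weights)
    upper_bound = mean(weights) * ratio
    # Clip only the upper tail, as the docstring describes for the ratio option.
    trimmed = [min(w, upper_bound) for w in weights]
    if keep_sum_of_weights:
        trimmed_sum = sum(trimmed)
        if trimmed_sum > 0:
            # Rescale so the trimmed weights add up to the original total.
            scale = original_sum / trimmed_sum
            trimmed = [w * scale for w in trimmed]
    return trimmed


print(trim_by_mean_ratio([1.0, 1.0, 1.0, 9.0], ratio=2.0, keep_sum_of_weights=False))
print(trim_by_mean_ratio([1.0, 1.0, 1.0, 9.0], ratio=2.0))
```

Note that rescaling after clipping can push the largest weight back above the clipping bound; whether the package compensates for that is not visible in this diff.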

balance/balancedf_class.py

Lines changed: 31 additions & 33 deletions
@@ -5,8 +5,10 @@

 # pyre-strict

+from __future__ import annotations
+
 import logging
-from typing import Any, Dict, Literal, Optional, Tuple, Union
+from typing import Any, Dict, Literal, Tuple

 import numpy as np
 import numpy.typing as npt
@@ -106,14 +108,14 @@ def _sample(self: "BalanceDF") -> "Sample":
     @property
     def _weights(
         self: "BalanceDF",
-    ) -> Optional[pd.DataFrame]:
+    ) -> pd.DataFrame | None:
         """Access the weight_column in __sample.

         Args:
             self (BalanceDF): Object

         Returns:
-            Optional[pd.DataFrame]: The weights (with no column name)
+            pd.DataFrame | None: The weights (with no column name)
         """
         w = self._sample.weight_column
         return w.rename(None)
@@ -123,13 +125,11 @@ def _BalanceDF_child_from_linked_samples(
         self: "BalanceDF",
     ) -> Dict[
         str,
-        Union[
-            "BalanceDF",
-            "BalanceCovarsDF",
-            "BalanceWeightsDF",
-            "BalanceOutcomesDF",
-            None,
-        ],
+        "BalanceDF"
+        | "BalanceCovarsDF"
+        | "BalanceWeightsDF"
+        | "BalanceOutcomesDF"
+        | None,
     ]:
         """Returns a dict with self and the same type of BalanceDF_child when created from the linked samples.

@@ -270,13 +270,11 @@ def _BalanceDF_child_from_linked_samples(
         BalanceDF_child_method = self.__name
         d: Dict[
             str,
-            Union[
-                "BalanceDF",
-                "BalanceCovarsDF",
-                "BalanceWeightsDF",
-                "BalanceOutcomesDF",
-                None,
-            ],
+            "BalanceDF"
+            | "BalanceCovarsDF"
+            | "BalanceWeightsDF"
+            | "BalanceOutcomesDF"
+            | None,
         ] = {"self": self}
         d.update(
             {
@@ -492,7 +490,7 @@ def _descriptive_stats(
         )
         return wdf

-    def to_download(self: "BalanceDF", tempdir: Optional[str] = None) -> FileLink:
+    def to_download(self: "BalanceDF", tempdir: str | None = None) -> FileLink:
         """Creates a downloadable link of the DataFrame, with ids, of the BalanceDF object.

         File name starts with tmp_balance_out_, and some random file name (using :func:`uuid.uuid4`).
@@ -1012,7 +1010,7 @@ def mean_with_ci(
     # NOTE: Summary could return also an str in case it is overridden in other children's methods.
     def summary(
         self: "BalanceDF", on_linked_samples: bool = True
-    ) -> Union[pd.DataFrame, str]:
+    ) -> pd.DataFrame | str:
         """
         Returns a summary of the BalanceDF object.

@@ -1038,14 +1036,14 @@ def summary(

     def _get_df_and_weights(
         self: "BalanceDF",
-    ) -> Tuple[pd.DataFrame, Optional[npt.NDArray]]:
+    ) -> Tuple[pd.DataFrame, npt.NDArray | None]:
         """Extract covars df (after using model_matrix) and weights from a BalanceDF object.

         Args:
             self (BalanceDF): Object

         Returns:
-            Tuple[pd.DataFrame, Optional[np.ndarray]]:
+            Tuple[pd.DataFrame, np.ndarray | None]:
                 A pd.DataFrame output from running :func:`model_matrix`, and
                 A np.ndarray of weights from :func:`_weights`, or just None (if there are no weights).
         """
@@ -1119,7 +1117,7 @@ def _asmd_BalanceDF(
     def asmd(
         self: "BalanceDF",
         on_linked_samples: bool = True,
-        target: Optional["BalanceDF"] = None,
+        target: "BalanceDF" | None = None,
         aggregate_by_main_covar: bool = False,
         **kwargs: Any,
     ) -> pd.DataFrame:
@@ -1246,8 +1244,8 @@ def asmd(

     def asmd_improvement(
         self: "BalanceDF",
-        unadjusted: Optional["BalanceDF"] = None,
-        target: Optional["BalanceDF"] = None,
+        unadjusted: "BalanceDF" | None = None,
+        target: "BalanceDF" | None = None,
     ) -> np.float64:
         """Calculates the improvement in mean(asmd) from before to after applying some weight adjustment.

@@ -1374,10 +1372,10 @@ def _df_with_ids(self: "BalanceDF") -> pd.DataFrame:

     def to_csv(
         self: "BalanceDF",
-        path_or_buf: Optional[FilePathOrBuffer] = None,
+        path_or_buf: FilePathOrBuffer | None = None,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[str]:
+    ) -> str | None:
         """Write df with ids from BalanceDF to a comma-separated values (csv) file.

         Uses :func:`pd.DataFrame.to_csv`.
@@ -1414,9 +1412,9 @@ def __init__(self: "BalanceOutcomesDF", sample: Sample) -> None:
     # this will also require to update _relative_response_rates a bit.
     def relative_response_rates(
         self: "BalanceOutcomesDF",
-        target: Union[bool, pd.DataFrame] = False,
+        target: bool | pd.DataFrame = False,
         per_column: bool = False,
-    ) -> Optional[pd.DataFrame]:
+    ) -> pd.DataFrame | None:
         """Produces a summary table of number of responses and proportion of completed responses.

         See :func:`general_stats.relative_response_rates`.
@@ -1513,7 +1511,7 @@ def relative_response_rates(
             self.df, df_target, per_column=per_column
         )

-    def target_response_rates(self: "BalanceOutcomesDF") -> Optional[pd.DataFrame]:
+    def target_response_rates(self: "BalanceOutcomesDF") -> pd.DataFrame | None:
         """Calculates relative_response_rates for the target in a Sample object.

         See :func:`general_stats.relative_response_rates`.
@@ -1569,7 +1567,7 @@ def target_response_rates(self: "BalanceOutcomesDF") -> Optional[pd.DataFrame]:
     # The BalanceDF.summary method only returns a DataFrame. So it's a question
     # what is the best way to structure this more generally.
     def summary(
-        self: "BalanceOutcomesDF", on_linked_samples: Optional[bool] = None
+        self: "BalanceOutcomesDF", on_linked_samples: bool | None = None
     ) -> str:
         """Produces summary printable string of a BalanceOutcomesDF object.

@@ -1831,8 +1829,8 @@ def _weights(self: "BalanceWeightsDF") -> None:

     def trim(
         self: "BalanceWeightsDF",
-        ratio: Optional[Union[float, int]] = None,
-        percentile: Optional[float] = None,
+        ratio: float | int | None = None,
+        percentile: float | None = None,
         keep_sum_of_weights: bool = True,
     ) -> None:
         """Trim weights in the sample object.
@@ -1859,7 +1857,7 @@ def trim(
         )

     def summary(
-        self: "BalanceWeightsDF", on_linked_samples: Optional[bool] = None
+        self: "BalanceWeightsDF", on_linked_samples: bool | None = None
     ) -> pd.DataFrame:
         """
         Generates a summary of a BalanceWeightsDF object.
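
Several signatures in this file combine a quoted forward reference with the union operator, for example `target: "BalanceDF" | None`. As an expression that is `str | None`, which the interpreter rejects, so the spelling only works because annotations are deferred. The sketch below illustrates this with a hypothetical class `BalanceDFSketch`. It is not the package's code, and the module source is again held in a string so the `__future__` import can come first.

```python
# Hypothetical module mirroring the quoted-forward-reference style in the diff.
MODULE_SOURCE = '''
from __future__ import annotations

class BalanceDFSketch:
    def asmd(self, target: "BalanceDFSketch" | None = None) -> str | None:
        return None if target is None else "computed"
'''

namespace: dict = {}
exec(compile(MODULE_SOURCE, "<forward_ref_demo>", "exec"), namespace)
BalanceDFSketch = namespace["BalanceDFSketch"]

# The definition succeeds and the annotation is kept as an unevaluated string.
target_annotation = BalanceDFSketch.asmd.__annotations__["target"]
print(type(target_annotation).__name__, target_annotation)

# Evaluating the same spelling as a real expression fails: a str instance
# does not support `|` with None.
try:
    _ = "BalanceDFSketch" | None
except TypeError as exc:
    runtime_error = exc
else:
    runtime_error = None
print(runtime_error)
```

Anything that evaluates these annotation strings at runtime would presumably hit the same `TypeError`. Since annotations are already deferred, the unquoted form `BalanceDFSketch | None` is an alternative spelling that avoids the quoted operand.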
