Skip to content

Commit 421d23e

Browse files
neuralsorcerermeta-codesync[bot]
authored andcommitted
Guard winsorized weights against percentile overshoot (#190)
Summary: Added post-winsorization clipping that uses original percentile bounds to prevent numerical overshoots when trimming weights. - Fixes #188 Pull Request resolved: #190 Differential Revision: D88137688 Pulled By: talgalili fbshipit-source-id: de62c0897021c1da03d466e7ca1dfa08c1067d3b
1 parent 779b9ed commit 421d23e

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

balance/adjustment.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import scipy
1818

1919
from balance import util as balance_util
20+
from balance.testutil import _verify_value_type
2021
from balance.weighting_methods import (
2122
adjust_null as balance_adjust_null,
2223
cbps as balance_cbps,
@@ -272,11 +273,21 @@ def trim_weights(
272273
else:
273274
lower_limit = upper_limit = percentile
274275

276+
# Keep the original requested percentiles for exact clipping bounds,
277+
# but validate/adjust separately for the winsorization call so at least
278+
# one value is affected at the requested edge.
279+
clip_limits = (
280+
None if (lower_limit is None or lower_limit == 0) else lower_limit,
281+
None if (upper_limit is None or upper_limit == 0) else upper_limit,
282+
)
275283
adjusted_limits = (
276284
_validate_limit(lower_limit, n_weights),
277285
_validate_limit(upper_limit, n_weights),
278286
)
279287

288+
# Preserve the pre-trim weights to calculate strict clipping bounds.
289+
original_weights_for_bounds = weights.copy()
290+
280291
weights = scipy.stats.mstats.winsorize(
281292
weights, limits=adjusted_limits, inplace=False
282293
)
@@ -291,6 +302,26 @@ def trim_weights(
291302
name=original_name,
292303
)
293304

305+
# Clip to the exact percentile bounds to avoid small numerical overshoots
306+
# from scipy.stats.mstats.winsorize on certain inputs.
307+
lower_bound = (
308+
None
309+
if clip_limits[0] is None
310+
else np.quantile(
311+
original_weights_for_bounds, clip_limits[0], method="lower"
312+
)
313+
)
314+
upper_bound = (
315+
None
316+
if clip_limits[1] is None
317+
else np.quantile(
318+
original_weights_for_bounds,
319+
1 - _verify_value_type(clip_limits[1]),
320+
method="lower",
321+
)
322+
)
323+
weights = weights.clip(lower=lower_bound, upper=upper_bound)
324+
294325
if keep_sum_of_weights:
295326
weights = weights / np.mean(weights) * original_mean
296327

0 commit comments

Comments
 (0)