Skip to content

Commit 21d657a

Browse files
neuralsorcerermeta-codesync[bot]
authored and committed
Optimize rake by replacing ipfn with vectorized IPF (#135)
Summary: Added a vectorized `_run_ipf_numpy` helper that mirrors the original ipfn behaviour while avoiding the package dependency. Reworked the raking workflow to feed the new solver, rebuild the cell-weight mapping via DataFrame joins, and return the historical `rake_weight` series shape. Pull Request resolved: #135 Reviewed By: wesleytlee Differential Revision: D86672684 Pulled By: talgalili fbshipit-source-id: 578883799e1d15ca8a868964a4a7843049bf7fa2
1 parent 5714d1c commit 21d657a

File tree

5 files changed

+305
-48
lines changed

5 files changed

+305
-48
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ REQUIRES = [
4242
"plotly",
4343
"matplotlib",
4444
"statsmodels",
45-
"ipfn",
4645
"session-info",
4746
]
4847
```

balance/weighting_methods/rake.py

Lines changed: 108 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,104 @@
1515
from fractions import Fraction
1616

1717
from functools import reduce
18-
from typing import Callable, Dict, List, Union
18+
from typing import Callable, Dict, List, Tuple, Union
1919

2020
import numpy as np
2121
import pandas as pd
2222

2323
from balance import adjustment as balance_adjustment, util as balance_util
24-
from ipfn import ipfn
2524

2625
logger = logging.getLogger(__package__)
2726

2827

2928
# TODO: Add options for only marginal distributions input
29+
def _run_ipf_numpy(
30+
original: np.ndarray,
31+
target_margins: List[np.ndarray],
32+
convergence_rate: float,
33+
max_iteration: int,
34+
rate_tolerance: float,
35+
) -> Tuple[np.ndarray, int, pd.DataFrame]:
36+
"""Run iterative proportional fitting on a NumPy array.
37+
38+
This reimplements the minimal subset of the :mod:`ipfn` package that is
39+
required for balance's usage. The original implementation spends most of
40+
its time looping in pure Python over every slice of the contingency table,
41+
which is prohibitively slow for the high-cardinality problems we test
42+
against. The logic here mirrors the algorithm used by ``ipfn.ipfn`` but
43+
applies the adjustments in a vectorised manner, yielding identical
44+
numerical results with a fraction of the runtime.
45+
46+
The caller is expected to pass ``target_margins`` that correspond to
47+
single-axis marginals (which is how :func:`rake` constructs the inputs).
48+
"""
49+
50+
if original.ndim == 0:
51+
raise ValueError("`original` must have at least one dimension")
52+
53+
table = np.asarray(original, dtype=np.float64)
54+
margins = [np.asarray(margin, dtype=np.float64) for margin in target_margins]
55+
56+
# Pre-compute shapes and axes that are repeatedly required during the
57+
# iterative updates. Each entry in ``axis_shapes`` represents how a
58+
# one-dimensional scaling factor should be reshaped in order to broadcast
59+
# along the appropriate axis of ``table``.
60+
axis_shapes: List[Tuple[int, ...]] = []
61+
sum_axes: List[Tuple[int, ...]] = []
62+
for axis in range(table.ndim):
63+
shape = [1] * table.ndim
64+
shape[axis] = table.shape[axis]
65+
axis_shapes.append(tuple(shape))
66+
sum_axes.append(tuple(i for i in range(table.ndim) if i != axis))
67+
68+
conv = np.inf
69+
old_conv = -np.inf
70+
conv_history: List[float] = []
71+
iteration = 0
72+
73+
while (
74+
iteration <= max_iteration
75+
and conv > convergence_rate
76+
and abs(conv - old_conv) > rate_tolerance
77+
):
78+
old_conv = conv
79+
80+
# Sequentially update the table for each marginal. Because the
81+
# marginals correspond to single axes we can compute all scaling
82+
# factors at once, avoiding the expensive Python loops present in the
83+
# reference implementation.
84+
for axis, margin in enumerate(margins):
85+
current = table.sum(axis=sum_axes[axis])
86+
factors = np.ones_like(margin, dtype=np.float64)
87+
np.divide(margin, current, out=factors, where=current != 0)
88+
table *= factors.reshape(axis_shapes[axis])
89+
90+
# Measure convergence using the same criterion as ``ipfn.ipfn``. The
91+
# implementation there keeps the maximum absolute proportional
92+
# difference while naturally ignoring NaNs (which arise for 0/0). We
93+
# match that behaviour by treating NaNs as zero deviation.
94+
conv = 0.0
95+
for axis, margin in enumerate(margins):
96+
current = table.sum(axis=sum_axes[axis])
97+
with np.errstate(divide="ignore", invalid="ignore"):
98+
diff = np.abs(np.divide(current, margin) - 1.0)
99+
current_conv = float(np.nanmax(diff)) if diff.size else 0.0
100+
if math.isnan(current_conv):
101+
current_conv = 0.0
102+
if current_conv > conv:
103+
conv = current_conv
104+
105+
conv_history.append(conv)
106+
iteration += 1
107+
108+
converged = int(iteration <= max_iteration)
109+
iterations_df = pd.DataFrame(
110+
{"iteration": range(len(conv_history)), "conv": conv_history}
111+
).set_index("iteration")
112+
113+
return table, converged, iterations_df
114+
115+
30116
def rake(
31117
sample_df: pd.DataFrame,
32118
sample_weights: pd.Series,
@@ -179,24 +265,11 @@ def rake(
179265
# Calculate {# covariates}-dimensional array representation of the sample
180266
# for the ipfn algorithm
181267

182-
# Create a multi-index DataFrame with all possible combinations
268+
grouped_sample_series = sample_df.groupby(alphabetized_variables)["weight"].sum()
183269
index = pd.MultiIndex.from_product(categories, names=alphabetized_variables)
184-
full_df = pd.DataFrame(index=index).reset_index()
185-
186-
# Group by covariates and sum weights
187-
grouped_sample = (
188-
sample_df.groupby(alphabetized_variables)["weight"].sum().reset_index()
189-
)
190-
191-
# Merge to ensure all combinations are present (fill missing with 0)
192-
merged = (
193-
pd.merge(full_df, grouped_sample, on=alphabetized_variables, how="left")
194-
.fillna(0)
195-
.infer_objects(copy=False)
196-
)
197-
198-
# Reshape to n-dimensional array
199-
m_sample = merged["weight"].values.reshape([len(c) for c in categories])
270+
grouped_sample_full = grouped_sample_series.reindex(index, fill_value=0)
271+
m_sample = grouped_sample_full.to_numpy().reshape([len(c) for c in categories])
272+
m_fit_input = m_sample.copy()
200273

201274
# Calculate target margins for ipfn
202275
target_margins = []
@@ -208,7 +281,6 @@ def rake(
208281
)
209282
sums = sums.reindex(cats, fill_value=0)
210283
target_margins.append(sums.values)
211-
dimensions = [[i] for i in range(len(alphabetized_variables))]
212284

213285
logger.debug(
214286
"Raking algorithm running following settings: "
@@ -219,16 +291,13 @@ def rake(
219291
# for that specific set of covariates
220292
# no longer uses the dataframe version of the ipfn algorithm
221293
# due to incompatibility with latest Python versions
222-
ipfn_obj = ipfn.ipfn(
223-
original=m_sample,
224-
aggregates=target_margins,
225-
dimensions=dimensions,
226-
convergence_rate=convergence_rate,
227-
max_iteration=max_iteration,
228-
verbose=2,
229-
rate_tolerance=rate_tolerance,
294+
m_fit, converged, iterations = _run_ipf_numpy(
295+
m_fit_input,
296+
target_margins,
297+
convergence_rate,
298+
max_iteration,
299+
rate_tolerance,
230300
)
231-
m_fit, converged, iterations = ipfn_obj.iteration()
232301

233302
logger.debug(
234303
f"Raking algorithm terminated with following convergence: {converged}; "
@@ -238,14 +307,9 @@ def rake(
238307
if not converged:
239308
logger.warning("Maximum iterations reached, convergence was not achieved")
240309

241-
# Convert array representation of the weighted sample into a dataframe
242-
# Generate all possible combinations of categories (cartesian product)
243310
combos = list(itertools.product(*categories))
244-
# Flatten the array to match the order of combos
245-
weights = m_fit.flatten()
246-
# Build the DataFrame
247311
fit = pd.DataFrame(combos, columns=alphabetized_variables)
248-
fit["rake_weight"] = weights
312+
fit["rake_weight"] = m_fit.flatten()
249313

250314
raked = pd.merge(
251315
sample_df.reset_index(),
@@ -254,23 +318,24 @@ def rake(
254318
on=alphabetized_variables,
255319
)
256320

257-
# Merge back to original sample weights
258321
raked_rescaled = pd.merge(
259322
raked,
260-
(grouped_sample.rename(columns={"weight": "total_survey_weight"})),
323+
grouped_sample_series.reset_index().rename(
324+
columns={"weight": "total_survey_weight"}
325+
),
261326
how="left",
262327
on=alphabetized_variables,
263-
).set_index("index") # important if dropping rows due to NA
328+
).set_index("index")
264329

265-
# use above merge to give each individual sample its proportion of the
266-
# cell's total weight
267330
raked_rescaled["rake_weight"] = (
268331
raked_rescaled["rake_weight"] / raked_rescaled["total_survey_weight"]
269332
)
270-
# rescale weights to sum to target_sum_weights (sum of initial target weights)
333+
271334
w = (
272-
raked_rescaled["rake_weight"] / raked_rescaled["rake_weight"].sum()
273-
) * target_sum_weights
335+
raked_rescaled["rake_weight"]
336+
/ raked_rescaled["rake_weight"].sum()
337+
* target_sum_weights
338+
).rename("rake_weight")
274339
return {
275340
"weight": w,
276341
"model": {

benchmarks/benchmark_ipfn.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from __future__ import annotations
9+
10+
import copy
11+
import time
12+
from typing import Callable, Sequence
13+
14+
import numpy as np
15+
16+
from balance.weighting_methods.rake import _run_ipf_numpy
17+
18+
try:
19+
from ipfn import ipfn as ipfn_module
20+
except ImportError: # pragma: no cover - optional dependency for benchmarking
21+
ipfn_module = None
22+
23+
24+
def _build_problem(seed: int = 0) -> tuple[np.ndarray, list[np.ndarray]]:
25+
"""Construct a moderately sized contingency table and consistent margins."""
26+
27+
rng = np.random.default_rng(seed)
28+
shape = (8, 10, 12)
29+
table = rng.uniform(0.1, 5.0, size=shape)
30+
31+
margins: list[np.ndarray] = []
32+
for axis in range(table.ndim):
33+
margin = table.sum(axis=tuple(i for i in range(table.ndim) if i != axis))
34+
# Introduce mild perturbations while keeping totals consistent.
35+
noise = rng.uniform(0.95, 1.05, size=margin.shape)
36+
margin = margin * noise
37+
margin *= table.sum() / margin.sum()
38+
margins.append(margin)
39+
return table, margins
40+
41+
42+
def _timeit(func: Callable[[], np.ndarray], repeat: int = 7) -> float:
43+
start = time.perf_counter()
44+
for _ in range(repeat):
45+
func()
46+
end = time.perf_counter()
47+
return (end - start) / repeat
48+
49+
50+
def _run_ipfn_lib(original: np.ndarray, margins: Sequence[np.ndarray]) -> np.ndarray:
    """Solve the IPF problem with the reference ``ipfn`` package.

    Inputs are deep-copied so repeated benchmark runs start from the same
    table.

    Raises:
        RuntimeError: if the optional ``ipfn`` dependency is not installed.
    """

    if ipfn_module is None:
        raise RuntimeError(
            "The `ipfn` package is not installed. Install it with `pip install ipfn`"
            " to run the lib benchmark."
        )

    dimensions = [[axis] for axis in range(original.ndim)]
    solver = ipfn_module.ipfn(
        copy.deepcopy(original),
        list(map(np.copy, margins)),
        dimensions,
        convergence_rate=5e-7,
        max_iteration=1000,
        rate_tolerance=0.0,
        verbose=0,
    )
    return solver.iteration()
68+
69+
70+
def _run_ipfn_numpy(original: np.ndarray, margins: Sequence[np.ndarray]) -> np.ndarray:
    """Solve the IPF problem with balance's vectorised NumPy solver.

    Inputs are copied first so repeated benchmark runs start from the same
    table; the convergence/iteration metadata is discarded.
    """
    table_copy = np.array(original, copy=True)
    margin_copies = [np.array(m, copy=True) for m in margins]
    fitted, _converged, _history = _run_ipf_numpy(
        table_copy,
        margin_copies,
        convergence_rate=5e-7,
        max_iteration=1000,
        rate_tolerance=0.0,
    )
    return fitted
79+
80+
81+
def main() -> None:
    """Benchmark balance's NumPy IPF solver against the optional ``ipfn`` package.

    When ``ipfn`` is installed, first verifies that both solvers agree
    numerically (so the timings compare equivalent work), then reports the
    average runtime of each and the resulting speed-up. Without ``ipfn``,
    only the NumPy solver's timing is printed.
    """
    original, margins = _build_problem()

    numpy_solution = _run_ipfn_numpy(original, margins)
    if ipfn_module is not None:
        # Guard against benchmarking two solvers that disagree.
        lib_solution = _run_ipfn_lib(original, margins)
        np.testing.assert_allclose(lib_solution, numpy_solution, atol=1e-6)

    numpy_timing = _timeit(lambda: _run_ipfn_numpy(original, margins))

    if ipfn_module is None:
        print("ipfn package not installed; only NumPy implementation timing available.")
        print(f"NumPy IPF solver: {numpy_timing * 1000:.2f} ms per run")
        return

    lib_timing = _timeit(lambda: _run_ipfn_lib(original, margins))

    # Bug fix: the header previously claimed a "5-run average" while _timeit
    # defaults to 7 repetitions; report the actual repeat count.
    print("Iterative proportional fitting benchmark (7-run average)")
    print(f"lib ipfn.ipfn solver: {lib_timing * 1000:.2f} ms per run")
    print(f"NumPy _run_ipf_numpy solver: {numpy_timing * 1000:.2f} ms per run")
    print(f"Speed-up: {lib_timing / numpy_timing:.2f}x faster")
102+
103+
104+
if __name__ == "__main__": # pragma: no cover - convenience entry point
105+
main()

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
"plotly",
2626
"matplotlib",
2727
"statsmodels",
28-
"ipfn",
2928
"session-info",
3029
]
3130

0 commit comments

Comments
 (0)