Hi,
I'm using Bambi for online predictions on a rather large model. The introduction of sparse matrices has made a HUGE improvement in memory usage. However, I've run into a performance issue where `predict()` takes about 8 seconds for a single row. I'm calling it like this:

```python
model.predict(
    self.idata,
    kind="response",
    data=data,
    sample_new_groups=False,
    inplace=False,
    random_seed=1337,
)
```

I've managed to bring it down to milliseconds by doing 2 things:
- Installing the `cftime` library (if it's missing, something seems to check for the module's existence hundreds of times per call). This cuts about 40% off the total `predict()` time.
- This is the big one:
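  A small, Bambi-independent illustration of why merely installing the module helps: CPython caches *successful* imports in `sys.modules`, so a repeated check becomes a cheap dict lookup, while a *failed* import is not cached and re-runs the full import machinery every time (the function name here is mine, not Bambi's):

  ```python
  def module_available(name):
      """Return True if `name` is importable.

      A successful import is cached in sys.modules, so repeating the check
      is a cheap dict lookup. A failed import is NOT cached: every call
      re-runs the full import machinery (sys.path scan included), which is
      why probing for a missing optional dependency in a hot loop is slow.
      """
      try:
          __import__(name)
          return True
      except ImportError:
          return False
  ```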
When using `Model.predict()` repeatedly (e.g., in a web service), we noticed predictions taking ~6-8 seconds each. Profiling revealed that the bottleneck is not the actual inference but xarray's `to_stacked_array()` operation in `DistributionalComponent.predict_common`:

```python
b = posterior[X_terms].to_stacked_array("__variables__", to_stack_dims)
```

For models with many terms (ours has ~80), this operation is expensive because it iterates over each term, calling `expand_dims`, `stack`, and `concat`. In our case, this single line accounts for ~4 seconds per prediction.
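For anyone who wants to reproduce the profiling, a generic helper along the lines of what we used (the helper name is mine, not Bambi API):

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, **kwargs):
    """Run fn under cProfile and print the 10 heaviest entries by cumulative time."""
    profiler = cProfile.Profile()
    result = profiler.runcall(fn, *args, **kwargs)
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(10)
    print(stream.getvalue())
    return result
```

Calling e.g. `profile_call(model.predict, self.idata, kind="response", data=data)` is what surfaced `to_stacked_array` at the top of the cumulative-time column for us.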
To my understanding, the stacked posterior `b` only depends on `posterior`, `X_terms`, and `to_stack_dims`, none of which change between predictions for a given model. Only the design matrix `X` changes based on new input data.
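The cost is easy to reproduce outside Bambi with a toy posterior; a sketch (sizes chosen to roughly match our model, variable names are made up):

```python
import numpy as np
import xarray as xr

# A toy "posterior": 80 scalar terms over (chain, draw), mimicking a model with many terms.
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {f"term_{i}": (("chain", "draw"), rng.normal(size=(4, 1000))) for i in range(80)}
)

# This rebuilds the stacked array from scratch on every call,
# looping over all 80 variables internally.
b = posterior.to_stacked_array("__variables__", sample_dims=("chain", "draw"))
```

Since `posterior` is identical on every request, the loop over the 80 variables is pure repeated work.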
Proposed solution: cache the result of `to_stacked_array()` on the component instance:

```python
cache_key = (tuple(X_terms), to_stack_dims)
if not hasattr(self, "_stacked_posterior_cache"):
    self._stacked_posterior_cache = {}
if cache_key not in self._stacked_posterior_cache:
    self._stacked_posterior_cache[cache_key] = posterior[X_terms].to_stacked_array(
        "__variables__", to_stack_dims
    )
b = self._stacked_posterior_cache[cache_key]
```

This makes the first prediction take the normal ~4 s, but subsequent predictions become nearly instant. For serving models in production, this is a significant improvement.
For now we're patching the method, but it would be really cool if Bambi could incorporate such a change to make online predictions work at scale out of the box. The main use case is predicting one row at a time, but thousands of times, so batching isn't possible since each row arrives from a separate web request.
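One caveat with caching on the instance: if the posterior is ever replaced (e.g., after refitting), the cached array goes stale. A Bambi-independent sketch of the same memoization pattern that additionally keys on the posterior's identity (the class and all names are mine, purely illustrative):

```python
class MemoizedBuilder:
    """Cache an expensive build step, keyed on the identity of its inputs.

    Note: id() can be reused after an object is garbage-collected, so this
    assumes the posterior stays alive for the lifetime of the cache.
    """

    def __init__(self):
        self._cache = {}
        self.builds = 0  # for illustration: how often the expensive path ran

    def get(self, posterior, x_terms, to_stack_dims, build):
        key = (id(posterior), tuple(x_terms), tuple(to_stack_dims))
        if key not in self._cache:
            self.builds += 1
            self._cache[key] = build()
        return self._cache[key]


# Stand-ins for the real objects:
cache = MemoizedBuilder()
posterior = object()            # stands in for an arviz/xarray posterior
build = lambda: "stacked-b"     # stands in for posterior[X_terms].to_stacked_array(...)

first = cache.get(posterior, ["Intercept", "x"], ("chain", "draw"), build)
second = cache.get(posterior, ["Intercept", "x"], ("chain", "draw"), build)
```

A fresh posterior object produces a new key, so a refit model rebuilds the stacked array instead of serving stale draws.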
```python
import logging

import numpy as np
import xarray

# Import path may differ slightly across Bambi versions.
from bambi.model_components import DistributionalComponent

logger = logging.getLogger(__name__)

_original_predict_common = DistributionalComponent.predict_common


def _patched_predict_common(
    self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict
):
    from bambi.utils import get_aliased_name

    x_offsets = []
    linear_predictor = 0
    response_dim = design_matrix_dims[0]

    if in_sample:
        X = self.design.common.design_matrix
    else:
        X = self.design.common.evaluate_new_data(data).design_matrix

    # Add offset columns to their own design matrix and remove them from the common matrix
    for term in self.offset_terms:
        term_slice = self.design.common.slices[term]
        x_offsets.append(X[:, term_slice])
        X = np.delete(X, term_slice, axis=1)

    # Add HSGP components' contribution to the linear predictor
    hsgp_slices = []
    for term_name, term in self.hsgp_terms.items():
        term_slice = self.design.common.slices[term_name]
        x_slice = X[:, term_slice]
        hsgp_slices.append(term_slice)
        term_aliased_name = get_aliased_name(term)
        hsgp_to_stack_dims = (f"{term_aliased_name}_weights_dim",)

        if term.scale_predictors:
            maximum_distance = term.maximum_distance
        else:
            maximum_distance = 1

        if term.by_levels is not None:
            by_values = x_slice[:, -1].astype(int)
            x_slice = x_slice[:, :-1]
            phi = term.hsgp.prior_linearized(x_slice / maximum_distance)[0][by_values]
            weights = posterior[term_aliased_name].stack(
                __variables__=hsgp_to_stack_dims + (f"{term_aliased_name}_by",)
            )
            weights_dim = weights.coords[f"{term_aliased_name}_by"].to_numpy()[np.newaxis]
            by_values_dim = by_values[:, np.newaxis]
            mask = xarray.DataArray(
                weights_dim == by_values_dim,
                dims=[response_dim, f"{term_aliased_name}_by"],
            )
            hsgp_contribution = xarray.dot(
                xarray.DataArray(phi, dims=[response_dim, "__variables__"]),
                weights.where(mask, 0),
            )
        else:
            phi = term.hsgp.prior_linearized(x_slice / maximum_distance)[0]
            weights = posterior[term_aliased_name].stack(__variables__=hsgp_to_stack_dims)
            hsgp_contribution = xarray.dot(
                xarray.DataArray(phi, dims=[response_dim, "__variables__"]), weights
            )

        if hsgp_dict is not None:
            hsgp_dict[term_name] = hsgp_contribution
        linear_predictor += hsgp_contribution

    if hsgp_slices:
        X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1)

    if self.common_terms or self.intercept_term:
        X_terms = [get_aliased_name(term) for term in self.common_terms.values()]
        if self.intercept_term:
            X_terms.insert(0, get_aliased_name(self.intercept_term))

        # Cache the stacked posterior - this is the expensive operation
        cache_key = (tuple(X_terms), to_stack_dims)
        if not hasattr(self, "_stacked_posterior_cache"):
            self._stacked_posterior_cache = {}
        if cache_key not in self._stacked_posterior_cache:
            logger.info(f"Caching stacked posterior for predict_common {cache_key}")
            self._stacked_posterior_cache[cache_key] = posterior[X_terms].to_stacked_array(
                "__variables__", to_stack_dims
            )
        b = self._stacked_posterior_cache[cache_key]

        X = xarray.DataArray(X, dims=design_matrix_dims)
        linear_predictor += xarray.dot(X, b)

    if x_offsets:
        linear_predictor += np.column_stack(x_offsets).sum(axis=1)[:, np.newaxis, np.newaxis]

    return linear_predictor


DistributionalComponent.predict_common = _patched_predict_common
```
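Since this is a monkey patch, it may be worth making it reversible (e.g., in tests, or to fall back if a Bambi upgrade changes the method's signature). A small generic helper, independent of Bambi:

```python
import contextlib


@contextlib.contextmanager
def patched_attr(obj, name, replacement):
    """Temporarily replace obj.name with replacement, restoring the original on exit,
    even if the body raises."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)
```

Used as, e.g., `with patched_attr(DistributionalComponent, "predict_common", _patched_predict_common): ...`, so the original implementation is always restored afterwards.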