
Performance improvements for predict #962

@humana

Hi,

I'm using Bambi for online predictions on a rather large model. The introduction of sparse matrices has made a HUGE improvement in memory usage. However, I've run into two performance issues that make predict() take about 8 seconds for a single row. I'm calling it like this:

  model.predict(
      self.idata,
      kind="response",
      data=data,
      sample_new_groups=False,
      inplace=False,
      random_seed=1337,
  )

I've managed to bring it down to milliseconds with two changes:

  1. Installing the cftime library. Without it, predict() seems to check whether this module exists hundreds of times. This alone cuts about 40% off the total predict() time.
  2. Caching the stacked posterior. This is the big one:

When using Model.predict() repeatedly (e.g., in a web service), we noticed predictions taking ~6-8 seconds each. Profiling revealed that the bottleneck is not the actual inference but xarray's to_stacked_array() operation in DistributionalComponent.predict_common:

  b = posterior[X_terms].to_stacked_array("__variables__", to_stack_dims)

For models with many terms (ours has ~80), this operation is expensive because it iterates over each term, calling expand_dims, stack, and concat. In our case, this single line accounts for ~4 seconds per prediction.

As far as I can tell, the stacked posterior b depends only on posterior, X_terms, and to_stack_dims, none of which change between predictions for a given model. Only the design matrix X changes with new input data.
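To make this reasoning concrete, here is a minimal NumPy sketch (not Bambi code) of why the stacked posterior can be reused. The linear predictor is a dot product between the design matrix X and the stacked coefficients b, and only X depends on the new data. The term names, shapes, and the `stack_posterior` helper are hypothetical stand-ins for `posterior[X_terms].to_stacked_array(...)`, and the sketch assumes the columns of X follow the order of X_terms.

```python
import numpy as np

rng = np.random.default_rng(1337)
n_chains, n_draws = 2, 100

# Hypothetical terms and their number of columns in the design matrix
term_sizes = {"Intercept": 1, "x1": 1, "group_cat": 3}
X_terms = list(term_sizes)

# Hypothetical posterior: one array of shape (chain, draw, n_columns) per term
posterior = {
    name: rng.normal(size=(n_chains, n_draws, k)) for name, k in term_sizes.items()
}


def stack_posterior(posterior, X_terms):
    """Stand-in for to_stacked_array: concatenate terms along one variables axis."""
    return np.concatenate([posterior[term] for term in X_terms], axis=-1)


def linear_predictor(X, b):
    """X: (obs, variables), b: (chain, draw, variables) -> (obs, chain, draw)."""
    return np.einsum("ov,cdv->ocd", X, b)


# Stack once: b does not depend on the new data
b = stack_posterior(posterior, X_terms)

# Each "request" only builds a new one-row design matrix
X_new = rng.normal(size=(1, b.shape[-1]))
eta = linear_predictor(X_new, b)
```

Re-stacking for every request would give exactly the same b, so the only per-request work that has to remain is building X and the dot product.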

Proposed solution: Cache the result of to_stacked_array() on the component instance:

  cache_key = (tuple(X_terms), to_stack_dims)
  if not hasattr(self, "_stacked_posterior_cache"):
      self._stacked_posterior_cache = {}
  if cache_key not in self._stacked_posterior_cache:
      self._stacked_posterior_cache[cache_key] = posterior[X_terms].to_stacked_array(
          "__variables__", to_stack_dims
      )
  b = self._stacked_posterior_cache[cache_key]
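The same caching logic can be pulled out into a small helper and tested without Bambi. This is a hypothetical sketch: `StackedPosteriorCache` and `slow_stack` are illustrative names, not Bambi API, and `slow_stack` stands in for the expensive `to_stacked_array` call. It counts cache misses to show that stacking runs once per key, however many predictions follow.

```python
class StackedPosteriorCache:
    """Hypothetical helper: memoizes an expensive stacking function per key."""

    def __init__(self, stack_fn):
        self._stack_fn = stack_fn
        self._cache = {}
        self.misses = 0  # number of times stack_fn actually ran

    def get(self, posterior, X_terms, to_stack_dims):
        # tuple() keeps the key hashable even if lists are passed in
        key = (tuple(X_terms), tuple(to_stack_dims))
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = self._stack_fn(posterior, X_terms, to_stack_dims)
        return self._cache[key]

    def clear(self):
        # Needed if the component is ever used with a different posterior
        self._cache.clear()


def slow_stack(posterior, X_terms, to_stack_dims):
    # Stand-in for posterior[X_terms].to_stacked_array("__variables__", to_stack_dims)
    return [value for term in X_terms for value in posterior[term]]


posterior = {"Intercept": [0.5], "x1": [1.2, -0.3]}
cache = StackedPosteriorCache(slow_stack)
for _ in range(1000):
    b = cache.get(posterior, ["Intercept", "x1"], ("chain", "draw"))
print(cache.misses)  # 1
```

Because the key does not include the posterior itself, the cache has to be cleared if the same component is later used with a different posterior, for example after refitting the model.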

This makes the first prediction take the normal ~4s, but subsequent predictions become nearly instant. For serving models in production, this is a significant improvement.

For now we're patching the method, but it would be really cool if Bambi could incorporate such a change so that online predictions work at scale out of the box. Our main use case is predicting one row at a time, thousands of times; batching isn't possible because each row comes from a separate web request.

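  import logging
  
  import numpy as np
  import xarray
  
  from bambi.model_components import DistributionalComponent
  
  logger = logging.getLogger(__name__)
  
  # Keep a reference to the original method so the patch can be undone
  _original_predict_common = DistributionalComponent.predict_common
  
  
  def _patched_predict_common(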
      self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict
  ):
      from bambi.utils import get_aliased_name
  
      x_offsets = []
      linear_predictor = 0
      response_dim = design_matrix_dims[0]
  
      if in_sample:
          X = self.design.common.design_matrix
      else:
          X = self.design.common.evaluate_new_data(data).design_matrix
  
      # Move offset columns into their own design matrix and remove them from the common matrix
      for term in self.offset_terms:
          term_slice = self.design.common.slices[term]
          x_offsets.append(X[:, term_slice])
          X = np.delete(X, term_slice, axis=1)
  
      # Add HSGP components contribution to the linear predictor
      hsgp_slices = []
      for term_name, term in self.hsgp_terms.items():
          term_slice = self.design.common.slices[term_name]
          x_slice = X[:, term_slice]
          hsgp_slices.append(term_slice)
          term_aliased_name = get_aliased_name(term)
          hsgp_to_stack_dims = (f"{term_aliased_name}_weights_dim",)
  
          if term.scale_predictors:
              maximum_distance = term.maximum_distance
          else:
              maximum_distance = 1
  
          if term.by_levels is not None:
              by_values = x_slice[:, -1].astype(int)
              x_slice = x_slice[:, :-1]
              phi = term.hsgp.prior_linearized(x_slice / maximum_distance)[0][by_values]
              weights = posterior[term_aliased_name].stack(
                  __variables__=hsgp_to_stack_dims + (f"{term_aliased_name}_by",)
              )
              weights_dim = weights.coords[f"{term_aliased_name}_by"].to_numpy()[
                  np.newaxis
              ]
              by_values_dim = by_values[:, np.newaxis]
              mask = xarray.DataArray(
                  weights_dim == by_values_dim,
                  dims=[response_dim, f"{term_aliased_name}_by"],
              )
              hsgp_contribution = xarray.dot(
                  xarray.DataArray(phi, dims=[response_dim, "__variables__"]),
                  weights.where(mask, 0),
              )
          else:
              phi = term.hsgp.prior_linearized(x_slice / maximum_distance)[0]
              weights = posterior[term_aliased_name].stack(
                  __variables__=hsgp_to_stack_dims
              )
              hsgp_contribution = xarray.dot(
                  xarray.DataArray(phi, dims=[response_dim, "__variables__"]), weights
              )
  
          if hsgp_dict is not None:
              hsgp_dict[term_name] = hsgp_contribution
  
          linear_predictor += hsgp_contribution
  
      if hsgp_slices:
          X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1)
  
      if self.common_terms or self.intercept_term:
          X_terms = [get_aliased_name(term) for term in self.common_terms.values()]
          if self.intercept_term:
              X_terms.insert(0, get_aliased_name(self.intercept_term))
  
          # Cache the stacked posterior - this is the expensive operation.
          # The key ignores the posterior itself, so this assumes the posterior
          # does not change between calls on this component.
          cache_key = (tuple(X_terms), to_stack_dims)
          if not hasattr(self, "_stacked_posterior_cache"):
              self._stacked_posterior_cache = {}
          if cache_key not in self._stacked_posterior_cache:
              logger.info("Caching stacked posterior for predict_common %s", cache_key)
              self._stacked_posterior_cache[cache_key] = posterior[
                  X_terms
              ].to_stacked_array("__variables__", to_stack_dims)
          b = self._stacked_posterior_cache[cache_key]
  
          X = xarray.DataArray(X, dims=design_matrix_dims)
          linear_predictor += xarray.dot(X, b)
  
      if x_offsets:
          linear_predictor += np.column_stack(x_offsets).sum(axis=1)[
              :, np.newaxis, np.newaxis
          ]
  
      return linear_predictor
  
  
  DistributionalComponent.predict_common = _patched_predict_common
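The patch keeps a reference to the original method in `_original_predict_common` but never uses it. If you only want the patched behavior in a limited scope (for example, in a test that compares both versions), the standard library's `unittest.mock.patch.object` swaps the attribute and restores the original on exit. A minimal sketch with a hypothetical stand-in class, since the mechanism doesn't depend on Bambi:

```python
from unittest import mock


class DistributionalComponentStub:
    """Hypothetical stand-in for bambi's DistributionalComponent."""

    def predict_common(self, x):
        return f"original({x})"


def _patched_predict_common(self, x):
    # Plain function: it becomes a bound method once assigned on the class
    return f"patched({x})"


component = DistributionalComponentStub()

with mock.patch.object(
    DistributionalComponentStub, "predict_common", _patched_predict_common
):
    inside = component.predict_common(1)

outside = component.predict_common(1)
print(inside, outside)  # patched(1) original(1)
```

Inside the `with` block every instance of the class sees the patched method; on exit the original is restored, even if the block raises.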
