@@ -42,7 +42,7 @@ class ShiftedBetaGeoModel(CLVModel):
4242 This model requires data to be summarized by *recency*, *T*, and *cohort* for each customer.
4343 Modeling assumptions require *1 <= recency <= T*, and *T >= 2*.
4444
45- First introduced by Fader & Hardie in [1]_, with additional expressions described in [2]_.
45+ First introduced by Fader & Hardie in [1]_, with additional expressions and enhancements described in [2]_ and [3] .
4646
4747 Parameters
4848 ----------
@@ -52,12 +52,19 @@ class ShiftedBetaGeoModel(CLVModel):
5252 * `recency`: Time period of last contract renewal. It should equal *T* for active customers.
5353 * `T`: Max observed time period in the cohort. All customers in a given cohort share the same value for *T*.
5454 * `cohort`: Customer cohort label
55+ * Additional columns for static covariates
5556 model_config : dict, optional
5657 Dictionary of model prior parameters:
57- * `a`: Shape parameter of dropout process; defaults to `phi_purchase` * `kappa_purchase`
58- * `b`: Shape parameter of dropout process; defaults to `1-phi_dropout` * `kappa_dropout`
59- * `phi_dropout`: Pooling prior; defaults to `Prior("Uniform", lower=0, upper=1, dims="cohort")`
60- * `kappa_dropout`: Pooling prior; defaults to `Prior("Pareto", alpha=1, m=1, dims="cohort")`
58+ * `alpha`: Shape parameter of dropout process (cohort-level);
59+ defaults to `phi` * `kappa`
60+ * `beta`: Shape parameter of dropout process (cohort-level);
61+ defaults to `(1-phi)` * `kappa`
62+ * `phi`: Pooling prior if `alpha` and `beta` are not provided;
63+ defaults to `Prior("Uniform", lower=0, upper=1, dims="cohort")`
64+ * `kappa`: Pooling prior if `alpha` and `beta` are not provided;
65+ defaults to `Prior("Pareto", alpha=1, m=1, dims="cohort")`
66+ * `dropout_coefficient`: Prior for covariate coefficients; defaults to `Prior("Normal", mu=0, sigma=1)`
67+ * `dropout_covariate_cols`: List of column names for customer-level covariates.
6168 sampler_config : dict, optional
6269 Dictionary of sampler parameters. Defaults to *None*.
6370
@@ -122,6 +129,44 @@ class ShiftedBetaGeoModel(CLVModel):
122129 discount_rate=0.05,
123130 ).sel(cohort=cohort_name)
124131
132+ # Example with customer-level covariates
133+ model_with_covariates = ShiftedBetaGeoModel(
134+ data=pd.DataFrame(
135+ {
136+ "customer_id": [1, 2, 3, ...],
137+ "recency": [8, 1, 4, ...],
138+ "T": [8, 5, 5, ...],
139+ "cohort": ["2025-02", "2025-04", "2025-04", ...],
140+ "channel_covariate": [1, 0, 1, ...],
141+ "rating_covariate": [
142+ 2.172,
143+ 1.234,
144+ 2.345,
145+ ...,
146+ ], # time-invariant
147+ }
148+ ),
149+ model_config={
150+ "dropout_coefficient": Prior("Normal", mu=0, sigma=2),
151+ "dropout_covariate_cols": ["channel_covariate", "rating_covariate"],
152+ },
153+ )
154+ model_with_covariates.fit()
155+
156+ # Predictions with covariates require covariate columns in prediction data
157+ pred_data = pd.DataFrame(
158+ {
159+ "customer_id": [...],
160+ "T": [...],
161+ "cohort": [...],
162+ "channel_covariate": [...],
163+ "rating_covariate": [...],
164+ }
165+ )
166+ retention_with_covariates = model_with_covariates.expected_retention_rate(
167+ data=pred_data, future_t=1
168+ )
169+
125170
126171 References
127172 ----------
@@ -131,6 +176,9 @@ class ShiftedBetaGeoModel(CLVModel):
131176 .. [2] Fader, P. S., & Hardie, B. G. (2010). "Customer-Base Valuation in a Contractual Setting:
132177 The Perils of Ignoring Heterogeneity." Marketing Science, 29(1), 85-93.
133178 https://faculty.wharton.upenn.edu/wp-content/uploads/2012/04/Fader_hardie_contractual_mksc_10.pdf
179+ .. [3] Fader, Peter & G. S. Hardie, Bruce (2007).
180+ "Incorporating Time-Invariant Covariates into the Pareto/NBD and BG/NBD Models".
181+ https://www.brucehardie.com/notes/019/time_invariant_covariates.pdf
134182 """
135183
136184 _model_type = "Shifted Beta-Geometric"
@@ -141,9 +189,25 @@ def __init__(
141189 model_config : ModelConfig | None = None ,
142190 sampler_config : dict | None = None ,
143191 ):
192+ super ().__init__ (
193+ data = data ,
194+ model_config = model_config ,
195+ sampler_config = sampler_config ,
196+ non_distributions = ["dropout_covariate_cols" ],
197+ )
198+
199+ # Extract covariate columns from model_config
200+ self .dropout_covariate_cols = list (self .model_config ["dropout_covariate_cols" ])
201+
144202 self ._validate_cols (
145203 data ,
146- required_cols = ["customer_id" , "recency" , "T" , "cohort" ],
204+ required_cols = [
205+ "customer_id" ,
206+ "recency" ,
207+ "T" ,
208+ "cohort" ,
209+ * self .dropout_covariate_cols ,
210+ ],
147211 must_be_unique = ["customer_id" ],
148212 )
149213
@@ -152,12 +216,6 @@ def __init__(
152216 ):
153217 raise ValueError ("Model fitting requires 1 <= recency <= T, and T >= 2." )
154218
155- super ().__init__ (
156- data = data ,
157- model_config = model_config ,
158- sampler_config = sampler_config ,
159- )
160-
161219 self ._validate_cohorts (self .data , check_param_dims = ("alpha" , "beta" ))
162220
163221 # Create cohort dim & coords
@@ -219,31 +277,89 @@ def default_model_config(self) -> ModelConfig:
219277 # Cohort-level hierarchical defaults (no covariates)
220278 "phi" : Prior ("Uniform" , lower = 0 , upper = 1 , dims = "cohort" ),
221279 "kappa" : Prior ("Pareto" , alpha = 1 , m = 1 , dims = "cohort" ),
280+ "dropout_coefficient" : Prior ("Normal" , mu = 0 , sigma = 1 ),
281+ "dropout_covariate_cols" : [],
222282 }
223283
224284 def build_model (self ) -> None : # type: ignore[override]
225285 """Build the model."""
226286 coords = {
227287 "customer_id" : self .data ["customer_id" ],
228288 "cohort" : self .cohorts ,
289+ "dropout_covariate" : self .dropout_covariate_cols ,
229290 }
230291 with pm .Model (coords = coords ) as self .model :
231- # Cohort-level behavior only
232- if "alpha" in self .model_config and "beta" in self .model_config :
233- alpha = self .model_config ["alpha" ].create_variable ("alpha" )
234- beta = self .model_config ["beta" ].create_variable ("beta" )
235- else :
236- # hierarchical pooling of dropout rate priors
237- phi = self .model_config ["phi" ].create_variable ("phi" )
238- kappa = self .model_config ["kappa" ].create_variable ("kappa" )
292+ if self .dropout_covariate_cols :
293+ # Customer-level behavior with covariates
294+ dropout_data = pm .Data (
295+ "dropout_data" ,
296+ self .data [self .dropout_covariate_cols ],
297+ dims = ["customer_id" , "dropout_covariate" ],
298+ )
239299
240- alpha = pm .Deterministic ("alpha" , phi * kappa , dims = "cohort" )
241- beta = pm .Deterministic ("beta" , (1.0 - phi ) * kappa , dims = "cohort" )
300+ # Get scale parameters (cohort-level baseline)
301+ if "alpha" in self .model_config and "beta" in self .model_config :
302+ alpha_scale = self .model_config ["alpha" ].create_variable (
303+ "alpha_scale"
304+ )
305+ beta_scale = self .model_config ["beta" ].create_variable ("beta_scale" )
306+ else :
307+ # hierarchical pooling of dropout rate priors
308+ phi = self .model_config ["phi" ].create_variable ("phi" )
309+ kappa = self .model_config ["kappa" ].create_variable ("kappa" )
310+
311+ alpha_scale = pm .Deterministic (
312+ "alpha_scale" , phi * kappa , dims = "cohort"
313+ )
314+ beta_scale = pm .Deterministic (
315+ "beta_scale" , (1.0 - phi ) * kappa , dims = "cohort"
316+ )
317+
318+ # Get covariate coefficients
319+ self .model_config ["dropout_coefficient" ].dims = "dropout_covariate"
320+ dropout_coefficient_alpha = self .model_config [
321+ "dropout_coefficient"
322+ ].create_variable ("dropout_coefficient_alpha" )
323+ dropout_coefficient_beta = self .model_config [
324+ "dropout_coefficient"
325+ ].create_variable ("dropout_coefficient_beta" )
326+
327+ # Apply covariate effects to get customer-level parameters
328+ # expressions adapted from BG/NBD covariate extensions on p2 of [3]_:
329+ # https://www.brucehardie.com/notes/019/time_invariant_covariates.pdf
330+ alpha = pm .Deterministic (
331+ "alpha" ,
332+ alpha_scale [self .cohort_idx ]
333+ * pm .math .exp (
334+ - pm .math .dot (dropout_data , dropout_coefficient_alpha )
335+ ),
336+ dims = "customer_id" ,
337+ )
338+ beta = pm .Deterministic (
339+ "beta" ,
340+ beta_scale [self .cohort_idx ]
341+ * pm .math .exp (- pm .math .dot (dropout_data , dropout_coefficient_beta )),
342+ dims = "customer_id" ,
343+ )
242344
243- dropout = ShiftedBetaGeometric .dist (
244- alpha [self .cohort_idx ],
245- beta [self .cohort_idx ],
246- )
345+ dropout = ShiftedBetaGeometric .dist (alpha , beta )
346+ else :
347+ # Cohort-level behavior only, no covariates
348+ if "alpha" in self .model_config and "beta" in self .model_config :
349+ alpha = self .model_config ["alpha" ].create_variable ("alpha" )
350+ beta = self .model_config ["beta" ].create_variable ("beta" )
351+ else :
352+ # hierarchical pooling of dropout rate priors
353+ phi = self .model_config ["phi" ].create_variable ("phi" )
354+ kappa = self .model_config ["kappa" ].create_variable ("kappa" )
355+
356+ alpha = pm .Deterministic ("alpha" , phi * kappa , dims = "cohort" )
357+ beta = pm .Deterministic ("beta" , (1.0 - phi ) * kappa , dims = "cohort" )
358+
359+ dropout = ShiftedBetaGeometric .dist (
360+ alpha [self .cohort_idx ],
361+ beta [self .cohort_idx ],
362+ )
247363
248364 pm .Censored (
249365 "dropout" ,
@@ -270,6 +386,7 @@ def _extract_predictive_variables(
270386 required_cols = [
271387 "customer_id" ,
272388 * customer_varnames ,
389+ * self .dropout_covariate_cols ,
273390 ],
274391 must_be_unique = ["customer_id" ],
275392 )
@@ -280,28 +397,76 @@ def _extract_predictive_variables(
280397 "T must be a non-zero, positive whole number." ,
281398 )
282399
400+ # Validate cohorts in prediction data match any or all cohorts used to fit model
283401 cohorts_present = self ._validate_cohorts (pred_data , check_pred_data = True )
284402
285- # Extract alpha and beta parameters only for cohorts present in the data
403+ # Use cohorts in prediction data to extract only cohort-level parameters
286404 pred_cohorts = xarray .DataArray (
287405 cohorts_present .values ,
288406 dims = ("cohort" ,),
289407 coords = {"cohort" : cohorts_present .values },
290408 )
291- alpha_pred = self .fit_result ["alpha" ].sel (cohort = pred_cohorts )
292- beta_pred = self .fit_result ["beta" ].sel (cohort = pred_cohorts )
293409
294- # Create a cohort-by-customer DataArray to map alpha and beta cohort parameters to each customer
410+ # Create a cohort-by-customer array to map cohort parameters to each customer
295411 customer_cohort_map = pred_data .set_index ("customer_id" )["cohort" ]
296412
297- customer_cohort_mapping = xarray .DataArray (
298- customer_cohort_map .values ,
299- dims = ("customer_id" ,),
300- coords = {"customer_id" : customer_cohort_map .index },
301- name = "customer_cohort_mapping" ,
302- )
303- alpha_pred = alpha_pred .sel (cohort = customer_cohort_mapping )
304- beta_pred = beta_pred .sel (cohort = customer_cohort_mapping )
413+ if self .dropout_covariate_cols :
414+ # Get alpha and beta scale parameters for each cohort
415+ alpha_cohort = self .fit_result ["alpha_scale" ].sel (cohort = pred_cohorts )
416+ beta_cohort = self .fit_result ["beta_scale" ].sel (cohort = pred_cohorts )
417+ # Get dropout covariate coefficients
418+ dropout_coefficient_alpha = self .fit_result ["dropout_coefficient_alpha" ]
419+ dropout_coefficient_beta = self .fit_result ["dropout_coefficient_beta" ]
420+
421+ # Reconstruct customer-level alpha and beta with covariates
422+ # Create covariate xarray
423+ dropout_xarray = xarray .DataArray (
424+ pred_data [self .dropout_covariate_cols ].values ,
425+ dims = ["customer_id" , "dropout_covariate" ],
426+ coords = {
427+ "customer_id" : pred_data ["customer_id" ],
428+ "dropout_covariate" : self .dropout_covariate_cols ,
429+ },
430+ )
431+
432+ # Map cohort indices for each customer
433+ pred_cohort_idx = pd .Categorical (
434+ customer_cohort_map .values , categories = self .cohorts
435+ ).codes
436+
437+ # Reconstruct customer-level parameters
438+ alpha_pred = alpha_cohort .isel (
439+ cohort = xarray .DataArray (pred_cohort_idx , dims = "customer_id" )
440+ ) * np .exp (
441+ - xarray .dot (
442+ dropout_xarray , dropout_coefficient_alpha , dim = "dropout_covariate"
443+ )
444+ )
445+ alpha_pred .name = "alpha"
446+
447+ beta_pred = beta_cohort .isel (
448+ cohort = xarray .DataArray (pred_cohort_idx , dims = "customer_id" )
449+ ) * np .exp (
450+ - xarray .dot (
451+ dropout_xarray , dropout_coefficient_beta , dim = "dropout_covariate"
452+ )
453+ )
454+ beta_pred .name = "beta"
455+
456+ else :
457+ # Get alpha and beta parameters for each cohort
458+ alpha_cohort = self .fit_result ["alpha" ].sel (cohort = pred_cohorts )
459+ beta_cohort = self .fit_result ["beta" ].sel (cohort = pred_cohorts )
460+
461+ # Map cohorts to customer_id for alpha and beta
462+ customer_cohort_mapping = xarray .DataArray (
463+ customer_cohort_map .values ,
464+ dims = ("customer_id" ,),
465+ coords = {"customer_id" : customer_cohort_map .index },
466+ name = "customer_cohort_mapping" ,
467+ )
468+ alpha_pred = alpha_cohort .sel (cohort = customer_cohort_mapping )
469+ beta_pred = beta_cohort .sel (cohort = customer_cohort_mapping )
305470
306471 # Add cohorts as non-dimensional coordinates to merge with predictive variables
307472 alpha_pred = alpha_pred .assign_coords (
@@ -362,6 +527,7 @@ def expected_retention_rate(
362527 * `customer_id`: Unique customer identifier
363528 * `T`: Number of time periods customer has been active
364529 * `cohort`: Customer cohort label
530+ * Covariate columns specified in `dropout_covariate_cols` (if using covariates)
365531
366532 References
367533 ----------
@@ -410,6 +576,7 @@ def expected_probability_alive(
410576 * `customer_id`: Unique customer identifier
411577 * `T`: Number of time periods customer has been active
412578 * `cohort`: Customer cohort label
579+ * Covariate columns specified in `dropout_covariate_cols` (if using covariates)
413580
414581 References
415582 ----------
@@ -467,6 +634,7 @@ def expected_residual_lifetime(
467634 * `customer_id`: Unique customer identifier
468635 * `T`: Number of time periods customer has been active
469636 * `cohort`: Customer cohort label
637+ * Covariate columns specified in `dropout_covariate_cols` (if using covariates)
470638
471639 References
472640 ----------
@@ -519,6 +687,7 @@ def expected_retention_elasticity(
519687 * `customer_id`: Unique customer identifier
520688 * `T`: Number of time periods customer has been active
521689 * `cohort`: Customer cohort label
690+ * Covariate columns specified in `dropout_covariate_cols` (if using covariates)
522691
523692 References
524693 ----------
0 commit comments