Skip to content

Commit 660686b

Browse files
Add BetaGeoBetaBinomModel (#1031)
* beta_geo_beta_binom.py copy from prev pr * test_beta_geo_beta_binom.py copy from prev PR * beta_geo_beta_binom imports * basic.py validate homogeneous T * copy _logp fix from prev PR * notebook and WIP _distribution_new_customers * test_distribution_new_customers * TODOs and test coverage * docstrings * docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 561e5c3 commit 660686b

File tree

8 files changed

+7310
-446
lines changed

8 files changed

+7310
-446
lines changed

docs/source/notebooks/clv/dev/beta_geo_beta_binom.ipynb

Lines changed: 6072 additions & 426 deletions
Large diffs are not rendered by default.

pymc_marketing/clv/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""CLV models and utilities."""
1515

1616
from pymc_marketing.clv.models import (
17+
BetaGeoBetaBinomModel,
1718
BetaGeoModel,
1819
GammaGammaModel,
1920
GammaGammaModelIndividual,
@@ -34,6 +35,7 @@
3435

3536
__all__ = (
3637
"BetaGeoModel",
38+
"BetaGeoBetaBinomModel",
3739
"ParetoNBDModel",
3840
"GammaGammaModel",
3941
"GammaGammaModelIndividual",

pymc_marketing/clv/distributions.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -601,23 +601,17 @@ def logp(value, alpha, beta, gamma, delta, T):
601601
"""Log-likelihood of the distribution."""
602602
t_x = pt.atleast_1d(value[..., 0])
603603
x = pt.atleast_1d(value[..., 1])
604-
scalar_case = t_x.type.broadcastable == (True,)
605604

606605
for param in (t_x, x, alpha, beta, gamma, delta, T):
607606
if param.type.ndim > 1:
608607
raise NotImplementedError(
609608
f"BetaGeoBetaBinom logp only implemented for vector parameters, got ndim={param.type.ndim}"
610609
)
611-
if scalar_case:
612-
if param.type.broadcastable == (False,):
613-
raise NotImplementedError(
614-
f"Parameter {param} cannot be larger than scalar value"
615-
)
616610

617611
# Broadcast all the parameters so they are sequences.
618612
# Potentially inefficient, but otherwise ugly logic needed to unpack arguments in the scan function,
619613
# since sequences always precede non-sequences.
620-
_, alpha, beta, gamma, delta, T = pt.broadcast_arrays(
614+
t_x, alpha, beta, gamma, delta, T = pt.broadcast_arrays(
621615
t_x, alpha, beta, gamma, delta, T
622616
)
623617

pymc_marketing/clv/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from pymc_marketing.clv.models.basic import CLVModel
1818
from pymc_marketing.clv.models.beta_geo import BetaGeoModel
19+
from pymc_marketing.clv.models.beta_geo_beta_binom import BetaGeoBetaBinomModel
1920
from pymc_marketing.clv.models.gamma_gamma import (
2021
GammaGammaModel,
2122
GammaGammaModelIndividual,
@@ -25,6 +26,7 @@
2526

2627
__all__ = (
2728
"CLVModel",
29+
"BetaGeoBetaBinomModel",
2830
"GammaGammaModel",
2931
"GammaGammaModelIndividual",
3032
"BetaGeoModel",

pymc_marketing/clv/models/basic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def _validate_cols(
6161
data: pd.DataFrame,
6262
required_cols: Sequence[str],
6363
must_be_unique: Sequence[str] = (),
64+
must_be_homogenous: Sequence[str] = (),
6465
):
6566
existing_columns = set(data.columns)
6667
n = data.shape[0]
@@ -71,6 +72,11 @@ def _validate_cols(
7172
if required_col in must_be_unique:
7273
if data[required_col].nunique() != n:
7374
raise ValueError(f"Column {required_col} has duplicate entries")
75+
if required_col in must_be_homogenous:
76+
if data[required_col].nunique() != 1:
77+
raise ValueError(
78+
f"Column {required_col} has non-homogeneous entries"
79+
)
7480

7581
def __repr__(self) -> str:
7682
"""Representation of the model."""

0 commit comments

Comments
 (0)