Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Literal, cast

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
from pydantic import ConfigDict, InstanceOf, validate_call
Expand All @@ -44,6 +45,13 @@ def __init__(
non_distributions: list[str] | None = None,
):
self.data = data
if {"frequency", "recency", "T"}.issubset(data.columns):
self._check_inputs(
frequency=np.asarray(data["frequency"]),
recency=np.asarray(data["recency"]),
T=np.asarray(data["T"]),
check_frequency=getattr(self, "_check_frequency", True),
)
model_config = model_config or {}

deprecated_keys = [key for key in model_config if key.endswith("_prior")]
Expand Down Expand Up @@ -90,6 +98,42 @@ def _validate_cols(
if data[col].nunique() != 1:
raise ValueError(f"Column {col} has non-homogeneous entries")

@staticmethod
def _check_inputs(
    frequency: np.ndarray,
    recency: np.ndarray,
    T: np.ndarray,
    check_frequency: bool = True,
) -> None:
    """Validate input data for CLV models.

    Parameters
    ----------
    frequency : array-like
        Number of repeat purchases.
    recency : array-like
        Time of most recent purchase.
    T : array-like
        Total observation time.
    check_frequency : bool, default True
        If True, validate that frequency >= 0.
        If False, only the sign check on ``frequency`` is skipped; missing
        values in ``frequency`` are still rejected.

    Raises
    ------
    ValueError
        If any input contains missing values (``NaN``, ``None`` or
        ``pd.NA``), if ``frequency`` (when checked), ``recency`` or ``T``
        contain negative entries, or if any ``recency`` exceeds its ``T``.
    """
    # Coerce once so that plain lists / Series behave like ndarrays below.
    frequency = np.asarray(frequency)
    recency = np.asarray(recency)
    T = np.asarray(T)

    # ``np.isnan`` raises TypeError on object-dtype arrays, which is what a
    # nullable pandas column containing ``pd.NA`` converts to. ``pd.isna``
    # handles float NaN, ``None`` and ``pd.NA`` uniformly.
    if any(np.any(pd.isna(values)) for values in (frequency, recency, T)):
        raise ValueError("Input data contains NaN values.")
    if check_frequency and np.any(frequency < 0):
        raise ValueError("Frequency must be >= 0.")
    if np.any(recency < 0):
        raise ValueError("Recency must be >= 0.")
    if np.any(T < 0):
        raise ValueError("T must be >= 0.")
    if np.any(recency > T):
        raise ValueError("Recency cannot be greater than T.")

def __repr__(self) -> str:
"""Representation of the model."""
if not hasattr(self, "model"):
Expand Down
9 changes: 4 additions & 5 deletions pymc_marketing/clv/models/gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def expected_customer_lifetime_value(
Parameters
----------
transaction_model : ~CLVModel
Predictive model for future transactions. `BetaGeoModel` and `ParetoNBDModel` are currently supported.
Predictive model for future transactions. `BetaGeoModel`,
`ModifiedBetaGeoModel`, and `ParetoNBDModel` are currently supported.
data : ~pandas.DataFrame
DataFrame containing the following columns:

Expand All @@ -197,11 +198,9 @@ def expected_customer_lifetime_value(
DataArray containing estimated customer lifetime values

"""
# Use Gamma-Gamma estimates for the expected_spend values
data = data.copy()
predicted_monetary_value = self.expected_customer_spend(data=data)
data.loc[:, "future_spend"] = predicted_monetary_value.mean(
("chain", "draw")
).copy()
data["future_spend"] = predicted_monetary_value.mean(("chain", "draw")).copy()

return customer_lifetime_value(
transaction_model=transaction_model,
Expand Down
14 changes: 10 additions & 4 deletions pymc_marketing/clv/models/shifted_beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,16 @@ def __init__(
must_be_unique=["customer_id"],
)

if np.any(
(data["recency"] < 1) | (data["recency"] > data["T"]) | (data["T"] < 2)
):
raise ValueError("Model fitting requires 1 <= recency <= T, and T >= 2.")
self._check_inputs(
frequency=np.zeros(len(data)),
recency=np.asarray(data["recency"]),
T=np.asarray(data["T"]),
check_frequency=False,
)
if np.any(data["recency"] < 1):
raise ValueError("Recency must be >= 1.")
if np.any(data["T"] < 2):
raise ValueError("T must be >= 2.")

self._validate_cohorts(self.data, check_param_dims=("alpha", "beta"))

Expand Down
3 changes: 2 additions & 1 deletion pymc_marketing/clv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def customer_lifetime_value(
Parameters
----------
transaction_model : ~CLVModel
Predictive model for future transactions. `BetaGeoModel` and `ParetoNBDModel` are currently supported.
Predictive model for future transactions. `BetaGeoModel`,
`ModifiedBetaGeoModel`, and `ParetoNBDModel` are currently supported.
data : ~pandas.DataFrame
DataFrame containing the following columns:

Expand Down
76 changes: 76 additions & 0 deletions tests/clv/models/test_beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,82 @@ def test_customer_id_duplicate(self):
data=data,
)

def test_check_inputs_frequency_negative(self):
    """A negative frequency is rejected when the model is constructed."""
    bad_frame = pd.DataFrame(
        {"customer_id": [1], "frequency": [-1], "recency": [0], "T": [10]}
    )
    with pytest.raises(ValueError, match="Frequency must be >= 0"):
        BetaGeoModel(data=bad_frame)

def test_check_inputs_recency_negative(self):
    """A negative recency is rejected when the model is constructed."""
    bad_frame = pd.DataFrame(
        {"customer_id": [1], "frequency": [0], "recency": [-1], "T": [10]}
    )
    with pytest.raises(ValueError, match="Recency must be >= 0"):
        BetaGeoModel(data=bad_frame)

def test_check_inputs_T_negative(self):
    """A negative observation period T is rejected when the model is constructed."""
    bad_frame = pd.DataFrame(
        {"customer_id": [1], "frequency": [0], "recency": [0], "T": [-1]}
    )
    with pytest.raises(ValueError, match="T must be >= 0"):
        BetaGeoModel(data=bad_frame)

def test_check_inputs_recency_greater_than_T(self):
    """A recency larger than T is rejected when the model is constructed."""
    bad_frame = pd.DataFrame(
        {"customer_id": [1], "frequency": [0], "recency": [11], "T": [10]}
    )
    with pytest.raises(ValueError, match="Recency cannot be greater than T"):
        BetaGeoModel(data=bad_frame)

def test_check_inputs_valid_data_passes(self):
    """Valid data should pass _check_inputs and allow model construction."""
    typical_frame = pd.DataFrame(
        {
            "customer_id": np.asarray([1, 2]),
            "frequency": np.asarray([0, 3]),
            "recency": np.asarray([0, 5]),
            "T": np.asarray([10, 10]),
        }
    )
    # Boundary case: recency equal to T must still be accepted.
    boundary_frame = pd.DataFrame(
        {
            "customer_id": np.asarray([1]),
            "frequency": np.asarray([1]),
            "recency": np.asarray([10]),
            "T": np.asarray([10]),
        }
    )
    for frame in (typical_frame, boundary_frame):
        built = BetaGeoModel(data=frame)
        assert built.data is not None

@pytest.mark.parametrize(
"frequency, recency, logp_value",
[
Expand Down
114 changes: 114 additions & 0 deletions tests/clv/models/test_gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
import pymc as pm
import pytest
import xarray as xr
from pymc_extras.prior import Prior

from pymc_marketing.clv.models.gamma_gamma import (
Expand Down Expand Up @@ -80,6 +81,119 @@ def setup_class(cls):


class TestGammaGammaModel(BaseTestGammaGammaModel):
@patch("pymc_marketing.clv.models.gamma_gamma.customer_lifetime_value")
def test_expected_customer_lifetime_value_uses_expected_spend(self, mock_clv):
    """Test that expected_customer_lifetime_value always uses expected_customer_spend."""
    # Constant spend draws with a single chain and draw; stubbing
    # expected_customer_spend with them avoids having to fit the model.
    spend_draws = xr.DataArray(
        np.full((1, 1, len(self.data)), 15.0),
        dims=("chain", "draw", "customer_id"),
    )
    gg_model = GammaGammaModel(data=self.data)

    # The CLV utility is patched out, so a placeholder stands in for the
    # transaction model.
    placeholder_transaction_model = "dummy_transaction_model"

    with patch.object(
        gg_model, "expected_customer_spend", return_value=spend_draws
    ) as spend_stub:
        gg_model.expected_customer_lifetime_value(
            transaction_model=placeholder_transaction_model,
            data=self.data,
        )

        # The model must have computed the expected spend itself ...
        spend_stub.assert_called_once()

        # ... and forwarded its posterior mean to the CLV utility.
        forwarded = mock_clv.call_args.kwargs["data"]
        assert "future_spend" in forwarded.columns
        np.testing.assert_array_equal(forwarded["future_spend"], 15.0)

@pytest.mark.parametrize(
    "invalid_data, expected_error_match",
    [
        # Negative frequency.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    monetary_value=[23.5],
                    frequency=[-1],
                    recency=[5],
                    T=[10],
                )
            ),
            "Frequency must be >= 0.",
        ),
        # Negative recency.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    monetary_value=[19.3],
                    frequency=[2],
                    recency=[-5],
                    T=[10],
                )
            ),
            "Recency must be >= 0.",
        ),
        # Negative T.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    monetary_value=[11.2],
                    frequency=[2],
                    recency=[5],
                    T=[-10],
                )
            ),
            "T must be >= 0.",
        ),
        # Recency beyond T.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    monetary_value=[100.5],
                    frequency=[2],
                    recency=[15],
                    T=[10],
                )
            ),
            "Recency cannot be greater than T.",
        ),
        # Missing frequency in a nullable integer column.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    monetary_value=[23.5],
                    frequency=pd.Series([np.nan], dtype="Int64"),
                    recency=[5],
                    T=[10],
                )
            ),
            "Input data contains NaN values.",
        ),
        # Missing recency in a nullable integer column.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    monetary_value=[19.3],
                    frequency=[2],
                    recency=pd.Series([np.nan], dtype="Int64"),
                    T=[10],
                )
            ),
            "Input data contains NaN values.",
        ),
    ],
)
def test_check_inputs_validation(self, invalid_data, expected_error_match):
    """Each invalid frequency/recency/T frame is rejected with the expected message."""
    expect_rejection = pytest.raises(ValueError, match=expected_error_match)
    with expect_rejection:
        GammaGammaModel(data=invalid_data)

def test_missing_columns(self):
data_invalid = self.data.drop(columns="customer_id")
with pytest.raises(
Expand Down
58 changes: 58 additions & 0 deletions tests/clv/models/test_modified_beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,64 @@ def setup_class(cls):

mock_fit(cls.model, cls.chains, cls.draws, cls.rng)

@pytest.mark.parametrize(
    "invalid_data, expected_error_match",
    [
        # Negative recency.
        (
            pd.DataFrame(dict(customer_id=[1], frequency=[2], recency=[-5], T=[10])),
            "Recency must be >= 0.",
        ),
        # Negative T.
        (
            pd.DataFrame(dict(customer_id=[1], frequency=[2], recency=[5], T=[-10])),
            "T must be >= 0.",
        ),
        # Recency beyond T.
        (
            pd.DataFrame(dict(customer_id=[1], frequency=[2], recency=[15], T=[10])),
            "Recency cannot be greater than T.",
        ),
        # Missing frequency in a nullable integer column.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    frequency=pd.Series([np.nan], dtype="Int64"),
                    recency=[5],
                    T=[10],
                )
            ),
            "Input data contains NaN values.",
        ),
        # Missing recency in a nullable integer column.
        (
            pd.DataFrame(
                dict(
                    customer_id=[1],
                    frequency=[2],
                    recency=pd.Series([np.nan], dtype="Int64"),
                    T=[10],
                )
            ),
            "Input data contains NaN values.",
        ),
    ],
)
def test_check_inputs_validation(self, invalid_data, expected_error_match):
    """Invalid recency/T values and missing entries are rejected at construction."""
    expect_rejection = pytest.raises(ValueError, match=expected_error_match)
    with expect_rejection:
        ModifiedBetaGeoModel(data=invalid_data)

def test_expected_customer_lifetime_value_removed(self):
    """ModifiedBetaGeoModel no longer exposes expected_customer_lifetime_value."""
    single_customer = pd.DataFrame(
        dict(customer_id=[1], frequency=[2], recency=[5], T=[10])
    )
    built = ModifiedBetaGeoModel(data=single_customer)
    assert not hasattr(built, "expected_customer_lifetime_value")

@pytest.fixture(scope="class")
def model_config(self):
return {
Expand Down
Loading