diff --git a/environment.yml b/environment.yml
index 4751a37b..440297e1 100644
--- a/environment.yml
+++ b/environment.yml
@@ -15,6 +15,7 @@ dependencies:
 - pydantic
 - preliz>=0.20.0
 - pyprojroot
+- narwhals
 # NOTE: Keep minimum pymc version in sync with ci.yml `OLDEST_PYMC_VERSION`
 - pymc>=5.23.0
 - nutpie>=0.15.1
diff --git a/pymc_marketing/clv/utils.py b/pymc_marketing/clv/utils.py
index e7773834..48058810 100644
--- a/pymc_marketing/clv/utils.py
+++ b/pymc_marketing/clv/utils.py
@@ -14,11 +14,13 @@
 """Utilities for the CLV module."""
 
 import warnings
-from datetime import date, datetime
+from datetime import date, datetime, timedelta
 
+import narwhals as nw
 import numpy as np
 import pandas
 import xarray
+from narwhals.typing import IntoFrameT
 from numpy import datetime64
 
 __all__ = [
@@ -156,6 +158,63 @@ def _squeeze_dims(x: xarray.DataArray):
     return clv.transpose("chain", "draw", "customer_id")
 
 
+def _find_first_transactions_alternative(
+    transactions: IntoFrameT,
+    customer_id_col: str,
+    datetime_col: str,
+    monetary_value_col: str | None = None,
+    datetime_format: str | None = None,
+) -> IntoFrameT:
+    """Flag each customer's first purchase date on any narwhals-compatible frame.
+
+    Backend-agnostic counterpart of `_find_first_transactions`. Transactions are
+    collapsed to one row per customer and purchase date, and a boolean ``first``
+    column marks the earliest purchase date of every customer.
+
+    Parameters
+    ----------
+    transactions : IntoFrameT
+        Transaction log in any dataframe type that narwhals can wrap.
+    customer_id_col : str
+        Column holding the customer identifier.
+    datetime_col : str
+        Column holding the purchase date. It must be orderable, since the
+        first purchase is found with ``min()``.
+    monetary_value_col : str, optional
+        Column holding the purchase value. When given, values are summed per
+        customer and purchase date.
+    datetime_format : str, optional
+        Currently unused by this implementation.
+
+    Returns
+    -------
+    IntoFrameT
+        Frame of the same native type as ``transactions`` with the columns
+        ``customer_id_col``, ``datetime_col``, optionally
+        ``monetary_value_col``, and ``first``. Row order is not guaranteed.
+    """
+    frame = nw.from_native(transactions)
+    keys = [customer_id_col, datetime_col]
+
+    first_date = frame.group_by(customer_id_col).agg(
+        first_date=nw.col(datetime_col).min()
+    )
+
+    if monetary_value_col is None:
+        # Nothing to aggregate: de-duplicating the key columns yields the same
+        # one-row-per-purchase-date frame without an empty ``agg()`` call.
+        per_date = frame.select(keys).unique()
+    else:
+        per_date = frame.group_by(keys).agg(nw.col(monetary_value_col).sum())
+
+    return (
+        per_date.join(first_date, on=customer_id_col)
+        .with_columns(first=nw.col(datetime_col) == nw.col("first_date"))
+        .drop("first_date")
+        .to_native()
+    )
+
+
 def _find_first_transactions(
     transactions: pandas.DataFrame,
     customer_id_col: str,
@@ -264,6 +323,113 @@ def _find_first_transactions(
     return period_transactions[select_columns]
 
 
+def rfm_summary_alternative(
+    transactions: IntoFrameT,
+    customer_id_col: str,
+    datetime_col: str,
+    monetary_value_col: str | None = None,
+    datetime_format: str | None = None,
+    observation_period_end: str | pandas.Period | datetime | None = None,
+    time_scaler: float = 1.0,
+) -> IntoFrameT:
+    """Summarize transactions into per-customer RFM statistics using narwhals.
+
+    Backend-agnostic counterpart of `rfm_summary`.
+
+    Parameters
+    ----------
+    transactions : IntoFrameT
+        Transaction log in any dataframe type that narwhals can wrap.
+    customer_id_col : str
+        Column holding the customer identifier.
+    datetime_col : str
+        Column holding the purchase date.
+    monetary_value_col : str, optional
+        Column holding the purchase value. When given, the output carries its
+        per-customer mean over purchase dates.
+    datetime_format : str, optional
+        ``strptime`` format used to parse a string ``datetime_col`` and a string
+        ``observation_period_end``. When omitted, ``datetime_col`` is cast to
+        a datetime and the string format is inferred by pandas.
+    observation_period_end : str, pandas.Period or datetime, optional
+        End of the observation window. Defaults to the latest purchase date.
+        A ``pandas.Period`` is taken at its start.
+    time_scaler : float, default 1.0
+        ``recency`` and ``T`` are expressed in units of ``time_scaler`` days.
+
+    Returns
+    -------
+    IntoFrameT
+        Frame of the same native type as ``transactions`` with one row per
+        customer and the columns ``customer_id``, ``frequency``, ``recency``,
+        ``T`` and, when requested, ``monetary_value_col``.
+    """
+    frame = nw.from_native(transactions)
+
+    # Parse the dates once, up front, so that the group-by below is limited to
+    # plain ``col.min()``-style aggregations instead of cast-then-aggregate.
+    if datetime_format is None:
+        as_datetime = nw.col(datetime_col).cast(nw.Datetime)
+    else:
+        as_datetime = nw.col(datetime_col).str.to_datetime(format=datetime_format)
+    frame = frame.with_columns(as_datetime)
+
+    # Normalize the window end to a datetime so it can be subtracted from the
+    # date column; a raw string or Period would raise in the arithmetic below.
+    if observation_period_end is None:
+        period_end = frame[datetime_col].max()
+    elif isinstance(observation_period_end, pandas.Period):
+        period_end = observation_period_end.to_timestamp().to_pydatetime()
+    elif isinstance(observation_period_end, str):
+        period_end = pandas.to_datetime(
+            observation_period_end, format=datetime_format
+        ).to_pydatetime()
+    else:
+        period_end = observation_period_end
+
+    purchases = nw.from_native(
+        _find_first_transactions_alternative(
+            frame.to_native(),
+            customer_id_col=customer_id_col,
+            datetime_col=datetime_col,
+            monetary_value_col=monetary_value_col,
+            datetime_format=datetime_format,
+        )
+    )
+
+    # TODO: Support the various units
+    divisor = timedelta(days=1) * time_scaler
+
+    # NOTE(review): the mean covers every purchase date, the first included;
+    # confirm that this matches the definition used by `rfm_summary`.
+    monetary_aggs = (
+        [] if monetary_value_col is None else [nw.col(monetary_value_col).mean()]
+    )
+    output_cols = ["customer_id", "frequency", "recency", "T"]
+    if monetary_value_col is not None:
+        output_cols.append(monetary_value_col)
+
+    summary = (
+        purchases.group_by(customer_id_col)
+        .agg(
+            *monetary_aggs,
+            min=nw.col(datetime_col).min(),
+            max=nw.col(datetime_col).max(),
+            count=nw.col(datetime_col).len(),
+        )
+        .with_columns(
+            frequency=nw.col("count") - 1,
+            recency=(nw.col("max") - nw.col("min")) / divisor,
+            T=(period_end - nw.col("min")) / divisor,
+        )
+        .rename({customer_id_col: "customer_id"})
+        # Keep only the summary columns; min/max/count are intermediates.
+        .select(output_cols)
+    )
+
+    return summary.to_native()
+
+
 def rfm_summary(
     transactions: pandas.DataFrame,
     customer_id_col: str,
diff --git a/pyproject.toml b/pyproject.toml
index a884a905..cea2fefb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "numpy>=1.17",
     "pandas",
     "pydantic>=2.1.0",
+    "narwhals",
     # NOTE: Used as minimum pymc version with test.yml `OLDEST_PYMC_VERSION`
     "pymc>=5.23.0",
     "pytensor>=2.31.3",
diff --git a/tests/clv/test_utils.py b/tests/clv/test_utils.py
index 8acbea52..00dcfe8a 100644
--- a/tests/clv/test_utils.py
+++ b/tests/clv/test_utils.py
@@ -24,10 +24,12 @@
 from pymc_marketing.clv.utils import (
     _expected_cumulative_transactions,
     _find_first_transactions,
+    _find_first_transactions_alternative,
     _rfm_quartile_labels,
     customer_lifetime_value,
     rfm_segments,
     rfm_summary,
+    rfm_summary_alternative,
     rfm_train_test_split,
     to_xarray,
 )
@@ -286,6 +288,37 @@ def transaction_data(self) -> pd.DataFrame:
         ]
         return pd.DataFrame(d, columns=["identifier", "date", "monetary_value"])
 
+    def test_alternative(self, transaction_data):
+        """Check the narwhals-based helpers against the pandas ones."""
+        first = _find_first_transactions_alternative(
+            transaction_data,
+            "identifier",
+            "date",
+            monetary_value_col="monetary_value",
+        )
+        assert {"identifier", "date", "monetary_value", "first"} <= set(first.columns)
+        # Exactly one purchase date per customer is flagged as the first one.
+        assert (first.groupby("identifier")["first"].sum() == 1).all()
+
+        today = "2015-02-07"
+        expected = rfm_summary(
+            transaction_data,
+            "identifier",
+            "date",
+            observation_period_end=today,
+        )
+        actual = rfm_summary_alternative(
+            transaction_data,
+            "identifier",
+            "date",
+            observation_period_end=today,
+        )
+        pd.testing.assert_frame_equal(
+            actual.sort_values("customer_id").reset_index(drop=True),
+            expected,
+            check_dtype=False,
+        )
+
     def test_find_first_transactions_test_period_end_none(self, transaction_data):
         max_date = transaction_data["date"].max()
         pd.testing.assert_frame_equal(
@@ -855,6 +888,28 @@ def test_rfm_quartile_labels(self):
         frequency = _rfm_quartile_labels("f_quartile", 4)
         assert frequency == range(1, 4)
 
+    def test_rfm_summary_with_time_scaler(self, transaction_data):
+        today = "2015-02-07"
+        actual = rfm_summary(
+            transaction_data,
+            "identifier",
+            "date",
+            observation_period_end=today,
+            time_scaler=10,
+        )
+        expected = pd.DataFrame(
+            [
+                [1, 1.0, 3.6, 3.7],
+                [2, 0.0, 0.0, 3.7],
+                [3, 2.0, 0.4, 3.7],
+                [4, 2.0, 2.0, 2.2],
+                [5, 2.0, 0.2, 2.2],
+                [6, 0.0, 0.0, 0.5],
+            ],
+            columns=["customer_id", "frequency", "recency", "T"],
+        )
+        assert_frame_equal(actual, expected)
+
 
 def test_expected_cumulative_transactions_dedups_inside_a_time_period(
     fitted_bg, cdnow_trans