Skip to content

Narwhals support for CLV aggregation #1809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- pydantic
- preliz>=0.20.0
- pyprojroot
- narwhals
# NOTE: Keep minimum pymc version in sync with ci.yml `OLDEST_PYMC_VERSION`
- pymc>=5.23.0
- nutpie>=0.15.1
Expand Down
80 changes: 79 additions & 1 deletion pymc_marketing/clv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
"""Utilities for the CLV module."""

import warnings
from datetime import date, datetime
from datetime import date, datetime, timedelta

import narwhals as nw
import numpy as np
import pandas
import xarray
from narwhals.typing import IntoFrameT
from numpy import datetime64

__all__ = [
Expand Down Expand Up @@ -156,6 +158,30 @@ def _squeeze_dims(x: xarray.DataArray):
return clv.transpose("chain", "draw", "customer_id")


def _find_first_transactions_alternative(
transactions: IntoFrameT,
customer_id_col: str,
datetime_col: str,
monetary_value_col: str | None = None,
datetime_format: str | None = None,
) -> IntoFrameT:
transactions = nw.from_native(transactions)

first_date = transactions.group_by(customer_id_col).agg(
first_date=nw.col(datetime_col).min()
)

agg_cols = [] if monetary_value_col is None else [nw.col(monetary_value_col).sum()]
agg = transactions.group_by([customer_id_col, datetime_col]).agg(*agg_cols)

return (
agg.join(first_date, on=customer_id_col)
.with_columns(first=nw.col(datetime_col) == nw.col("first_date"))
.drop("first_date")
.to_native()
)


def _find_first_transactions(
transactions: pandas.DataFrame,
customer_id_col: str,
Expand Down Expand Up @@ -264,6 +290,58 @@ def _find_first_transactions(
return period_transactions[select_columns]


def rfm_summary_alternative(
transactions: IntoFrameT,
customer_id_col: str,
datetime_col: str,
monetary_value_col: str | None = None,
datetime_format: str | None = None,
observation_period_end: str | pandas.Period | datetime | None = None,
time_scaler: float = 1.0,
) -> IntoFrameT:
transactions = nw.from_native(transactions)

date = nw.col(datetime_col).cast(nw.Datetime)

if observation_period_end is None:
observation_period_end = transactions[datetime_col].cast(nw.Datetime).max()

repeated_transactions = _find_first_transactions_alternative(
transactions,
customer_id_col=customer_id_col,
datetime_col=datetime_col,
monetary_value_col=monetary_value_col,
datetime_format=datetime_format,
)

# TODO: Support the various units
divisor = timedelta(days=1) * time_scaler

additional_cols = (
[] if monetary_value_col is None else [nw.col(monetary_value_col).mean()]
)

customers = (
nw.from_native(repeated_transactions)
.group_by(customer_id_col)
.agg(
*additional_cols,
min=date.min(),
max=date.max(),
count=date.len(),
)
.with_columns(
frequency=nw.col("count") - 1,
recency=(nw.col("max") - nw.col("min")) / divisor,
T=(observation_period_end - nw.col("min")) / divisor,
)
.rename({customer_id_col: "customer_id"})
# .select(["customer_id", "frequency", "recency"])
)

return customers.to_native()


def rfm_summary(
transactions: pandas.DataFrame,
customer_id_col: str,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"numpy>=1.17",
"pandas",
"pydantic>=2.1.0",
"narwhals",
# NOTE: Used as minimum pymc version with test.yml `OLDEST_PYMC_VERSION`
"pymc>=5.23.0",
"pytensor>=2.31.3",
Expand Down
48 changes: 48 additions & 0 deletions tests/clv/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
from pymc_marketing.clv.utils import (
_expected_cumulative_transactions,
_find_first_transactions,
_find_first_transactions_alternative,
_rfm_quartile_labels,
customer_lifetime_value,
rfm_segments,
rfm_summary,
rfm_summary_alternative,
rfm_train_test_split,
to_xarray,
)
Expand Down Expand Up @@ -286,6 +288,30 @@ def transaction_data(self) -> pd.DataFrame:
]
return pd.DataFrame(d, columns=["identifier", "date", "monetary_value"])

def test_alternative(self, transaction_data):
_ = _find_first_transactions(
transaction_data,
"identifier",
"date",
monetary_value_col="monetary_value",
)

_ = _find_first_transactions_alternative(
transaction_data,
"identifier",
"date",
monetary_value_col="monetary_value",
)

_ = rfm_summary_alternative(
transaction_data,
"identifier",
"date",
monetary_value_col="monetary_value",
)

assert 0

def test_find_first_transactions_test_period_end_none(self, transaction_data):
max_date = transaction_data["date"].max()
pd.testing.assert_frame_equal(
Expand Down Expand Up @@ -855,6 +881,28 @@ def test_rfm_quartile_labels(self):
frequency = _rfm_quartile_labels("f_quartile", 4)
assert frequency == range(1, 4)

def test_rfm_summary_with_time_scaler(self, transaction_data):
today = "2015-02-07"
actual = rfm_summary(
transaction_data,
"identifier",
"date",
observation_period_end=today,
time_scaler=10,
)
expected = pd.DataFrame(
[
[1, 1.0, 3.6, 3.7],
[2, 0.0, 0.0, 3.7],
[3, 2.0, 0.4, 3.7],
[4, 2.0, 2.0, 2.2],
[5, 2.0, 0.2, 2.2],
[6, 0.0, 0.0, 0.5],
],
columns=["customer_id", "frequency", "recency", "T"],
)
assert_frame_equal(actual, expected)


def test_expected_cumulative_transactions_dedups_inside_a_time_period(
fitted_bg, cdnow_trans
Expand Down
Loading