-
Notifications
You must be signed in to change notification settings - Fork 324
Narwhals support for CLV aggregation #1809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -14,11 +14,13 @@ | |||||||||
"""Utilities for the CLV module.""" | ||||||||||
|
||||||||||
import warnings | ||||||||||
from datetime import date, datetime | ||||||||||
from datetime import date, datetime, timedelta | ||||||||||
|
||||||||||
import narwhals as nw | ||||||||||
import numpy as np | ||||||||||
import pandas | ||||||||||
import xarray | ||||||||||
from narwhals.typing import IntoFrameT | ||||||||||
from numpy import datetime64 | ||||||||||
|
||||||||||
__all__ = [ | ||||||||||
|
@@ -156,6 +158,30 @@ def _squeeze_dims(x: xarray.DataArray): | |||||||||
return clv.transpose("chain", "draw", "customer_id") | ||||||||||
|
||||||||||
|
||||||||||
def _find_first_transactions_alternative(
    transactions: IntoFrameT,
    customer_id_col: str,
    datetime_col: str,
    monetary_value_col: str | None = None,
    datetime_format: str | None = None,
) -> IntoFrameT:
    """Collapse transactions to one row per (customer, date) and flag first purchases.

    Parameters
    ----------
    transactions : IntoFrameT
        Any narwhals-compatible native frame (pandas, polars, ...) of raw
        transactions.
    customer_id_col : str
        Column identifying the customer.
    datetime_col : str
        Column holding the transaction date/time.
    monetary_value_col : str, optional
        When given, the per-(customer, date) rows carry the summed value of
        this column.
    datetime_format : str, optional
        Accepted for interface parity with ``_find_first_transactions``;
        not consumed here — NOTE(review): confirm whether parsing with this
        format should happen in this function.

    Returns
    -------
    IntoFrameT
        Native frame with one row per (customer, date) plus a boolean
        ``first`` column marking each customer's earliest transaction date.
    """
    frame = nw.from_native(transactions)

    # Earliest transaction date observed for each customer.
    earliest = frame.group_by(customer_id_col).agg(
        first_date=nw.col(datetime_col).min()
    )

    # Deduplicate to one row per (customer, date); optionally sum the
    # monetary value of same-day transactions into that single row.
    value_aggs = [] if monetary_value_col is None else [nw.col(monetary_value_col).sum()]
    per_period = frame.group_by([customer_id_col, datetime_col]).agg(*value_aggs)

    # Mark rows whose date equals the customer's earliest date, then drop
    # the helper column before handing the native frame back.
    flagged = (
        per_period.join(earliest, on=customer_id_col)
        .with_columns(first=nw.col(datetime_col) == nw.col("first_date"))
        .drop("first_date")
    )
    return flagged.to_native()
|
||||||||||
|
||||||||||
def _find_first_transactions( | ||||||||||
transactions: pandas.DataFrame, | ||||||||||
customer_id_col: str, | ||||||||||
|
@@ -264,6 +290,58 @@ def _find_first_transactions( | |||||||||
return period_transactions[select_columns] | ||||||||||
|
||||||||||
|
||||||||||
def rfm_summary_alternative(
    transactions: IntoFrameT,
    customer_id_col: str,
    datetime_col: str,
    monetary_value_col: str | None = None,
    datetime_format: str | None = None,
    observation_period_end: str | pandas.Period | datetime | None = None,
    time_scaler: float = 1.0,
) -> IntoFrameT:
    """Build an RFM (recency/frequency/T[/monetary]) summary, one row per customer.

    Parameters
    ----------
    transactions : IntoFrameT
        Any narwhals-compatible native frame of raw transactions.
    customer_id_col : str
        Column identifying the customer.
    datetime_col : str
        Column holding the transaction date/time.
    monetary_value_col : str, optional
        When given, a mean monetary value per customer is included.
    datetime_format : str, optional
        Forwarded to ``_find_first_transactions_alternative``; not consumed
        there either — NOTE(review): confirm whether format-based parsing is
        still required.
    observation_period_end : str | pandas.Period | datetime, optional
        End of the observation window; defaults to the latest transaction
        date in the data. NOTE(review): a ``str`` value is used directly in
        datetime arithmetic below without parsing — confirm callers always
        pass a datetime-like, or parse here.
    time_scaler : float, default 1.0
        Scales the day-based time unit used for ``recency`` and ``T``.

    Returns
    -------
    IntoFrameT
        Native frame keyed by ``customer_id`` with ``frequency`` (repeat
        purchase count), ``recency`` (first-to-last purchase span) and ``T``
        (first purchase to observation end), both in scaled days.
    """
    transactions = nw.from_native(transactions)

    # Expression (not a value): the datetime column cast to nw.Datetime.
    # NOTE(review): this local shadows `datetime.date` imported at module
    # top — consider renaming to avoid confusion.
    date = nw.col(datetime_col).cast(nw.Datetime)

    if observation_period_end is None:
        # Latest transaction timestamp across the whole dataset.
        observation_period_end = transactions[datetime_col].cast(nw.Datetime).max()

    # NOTE(review): `transactions` is already a narwhals frame here; the
    # helper calls nw.from_native on it again — confirm pass-through is
    # supported by the pinned narwhals version.
    repeated_transactions = _find_first_transactions_alternative(
        transactions,
        customer_id_col=customer_id_col,
        datetime_col=datetime_col,
        monetary_value_col=monetary_value_col,
        datetime_format=datetime_format,
    )

    # TODO: Support the various units
    # Time unit: one day scaled by `time_scaler`; recency/T are expressed
    # in multiples of this divisor.
    divisor = timedelta(days=1) * time_scaler

    # Optional mean monetary value per customer (column name keeps
    # `monetary_value_col`).
    additional_cols = (
        [] if monetary_value_col is None else [nw.col(monetary_value_col).mean()]
    )

    customers = (
        nw.from_native(repeated_transactions)
        .group_by(customer_id_col)
        .agg(
            *additional_cols,
            min=date.min(),      # first purchase date
            max=date.max(),      # last purchase date
            count=date.len(),    # number of distinct purchase dates
        )
        .with_columns(
            # Repeat purchases only: total purchase dates minus the first.
            frequency=nw.col("count") - 1,
            recency=(nw.col("max") - nw.col("min")) / divisor,
            T=(observation_period_end - nw.col("min")) / divisor,
        )
        .rename({customer_id_col: "customer_id"})
        # .select(["customer_id", "frequency", "recency"])
    )

    return customers.to_native()
|
||||||||||
|
||||||||||
def rfm_summary( | ||||||||||
transactions: pandas.DataFrame, | ||||||||||
customer_id_col: str, | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very tempting, but consider creating a new column between operations - I would be afraid that for pandas the casting happens multiple times instead of once