Skip to content

Commit eaabc67

Browse files
fix(clv): Improve validation for missing columns in CLVModel (#1851)
* fix(clv): Improve validation for missing columns in CLVModel Improves the column validation check to find all missing columns at once and raises a single, informative ValueError listing all of them. This directly addresses the goal of issue #1734 by providing a much better user experience when data is incomplete. * style: Apply pre-commit fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f51f9f5 commit eaabc67

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

pymc_marketing/clv/models/basic.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,20 +75,22 @@ def _validate_cols(
7575
must_be_unique: Sequence[str] = (),
7676
must_be_homogenous: Sequence[str] = (),
7777
):
78-
existing_columns = set(data.columns)
78+
missing = set(required_cols).difference(data.columns)
79+
if missing:
80+
raise ValueError(
81+
"The following required columns are missing from the "
82+
f"input data: {sorted(list(missing))}"
83+
)
84+
7985
n = data.shape[0]
8086

81-
for required_col in required_cols:
82-
if required_col not in existing_columns:
83-
raise ValueError(f"Required column {required_col} missing")
84-
if required_col in must_be_unique:
85-
if data[required_col].nunique() != n:
86-
raise ValueError(f"Column {required_col} has duplicate entries")
87-
if required_col in must_be_homogenous:
88-
if data[required_col].nunique() != 1:
89-
raise ValueError(
90-
f"Column {required_col} has non-homogeneous entries"
91-
)
87+
for col in required_cols:
88+
if col in must_be_unique:
89+
if data[col].nunique() != n:
90+
raise ValueError(f"Column {col} has duplicate entries")
91+
if col in must_be_homogenous:
92+
if data[col].nunique() != 1:
93+
raise ValueError(f"Column {col} has non-homogeneous entries")
9294

9395
def __repr__(self) -> str:
9496
"""Representation of the model."""

tests/clv/models/test_basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,17 @@ def test_deprecation_warning_on_old_config(self):
274274
model = CLVModelTest(model_config=old_model_config)
275275

276276
assert model.model_config == {"x": Prior("Normal", mu=0, sigma=1)}
277+
278+
def test_validate_cols_reports_all_missing_columns(self):
279+
"""Test _validate_cols raises a single ValueError listing all missing columns."""
280+
required = ("customer_id", "frequency", "recency", "T")
281+
data = pd.DataFrame(
282+
{
283+
"customer_id": [1, 2, 3],
284+
"frequency": [1, 2, 3],
285+
}
286+
)
287+
expected_error_msg = "The following required columns are missing from the input data: ['T', 'recency']"
288+
289+
with pytest.raises(ValueError, match=expected_error_msg):
290+
CLVModel._validate_cols(data=data, required_cols=required)

0 commit comments

Comments
 (0)