Skip to content

Commit aa5fa1d

Browse files
authored
smaller test dataset for bgbb (#1039)
1 parent d5fa543 commit aa5fa1d

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

tests/clv/models/test_beta_geo_beta_binom.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,10 @@ def setup_class(cls):
4040
cls.gamma_true = 0.6567
4141

4242
# Use Quickstart dataset (the CDNOW_sample research data) for testing
43-
test_data = pd.read_csv("data/bgbb_donations.csv")
43+
cls.data = pd.read_csv("data/bgbb_donations.csv")
4444

45-
cls.data = test_data
46-
# cls.customer_id = test_data["customer_id"]
47-
# cls.frequency = test_data["frequency"]
48-
# cls.recency = test_data["recency"]
49-
# cls.T = test_data["T"]
45+
# sample from full dataset for tests involving model fits
46+
cls.sample_data = cls.data.sample(n=1000, random_state=45)
5047

5148
# take sample of all unique recency/frequency/T combinations to test predictive methods
5249
test_customer_ids = [
@@ -74,8 +71,8 @@ def setup_class(cls):
7471
11103,
7572
]
7673

77-
cls.sample_data = test_data.query("customer_id.isin(@test_customer_ids)")
78-
cls.sample_data_N = len(test_customer_ids)
74+
cls.pred_data = cls.data.query("customer_id.isin(@test_customer_ids)")
75+
cls.pred_data_N = len(test_customer_ids)
7976

8077
# Instantiate model with CDNOW data for testing
8178
cls.model = BetaGeoBetaBinomModel(cls.data)
@@ -278,13 +275,16 @@ def test_model_repr(self, custom_config):
278275
@pytest.mark.parametrize(
279276
"fit_method, rtol",
280277
[
281-
("mcmc", 0.1),
278+
(
279+
"mcmc",
280+
0.3,
281+
), # higher rtol required for sample_data; within .1 tolerance for full dataset;
282282
("map", 0.2),
283283
],
284284
)
285285
def test_model_convergence(self, fit_method, rtol, model_config):
286286
model = BetaGeoBetaBinomModel(
287-
data=self.data,
287+
data=self.sample_data,
288288
model_config=model_config,
289289
)
290290
model.build_model()
@@ -307,7 +307,7 @@ def test_model_convergence(self, fit_method, rtol, model_config):
307307
)
308308

309309
def test_fit_result_without_fit(self, model_config):
310-
model = BetaGeoBetaBinomModel(data=self.data, model_config=model_config)
310+
model = BetaGeoBetaBinomModel(data=self.pred_data, model_config=model_config)
311311
with pytest.raises(RuntimeError, match="The model hasn't been fit yet"):
312312
model.fit_result
313313

@@ -327,20 +327,20 @@ def test_expected_purchases(self, test_t):
327327
true_purchases = (
328328
self.lifetimes_model.conditional_expected_number_of_purchases_up_to_time(
329329
m_periods_in_future=test_t,
330-
frequency=self.sample_data["frequency"],
331-
recency=self.sample_data["recency"],
332-
n_periods=self.sample_data["T"],
330+
frequency=self.pred_data["frequency"],
331+
recency=self.pred_data["recency"],
332+
n_periods=self.pred_data["T"],
333333
)
334334
)
335335

336336
# test parametrization with default data has different dims
337337
est_num_purchases = self.model.expected_purchases(future_t=test_t)
338338
assert est_num_purchases.shape == (self.chains, self.draws, self.N)
339339

340-
data = self.sample_data.assign(future_t=test_t)
340+
data = self.pred_data.assign(future_t=test_t)
341341
est_num_purchases = self.model.expected_purchases(data)
342342

343-
assert est_num_purchases.shape == (self.chains, self.draws, self.sample_data_N)
343+
assert est_num_purchases.shape == (self.chains, self.draws, self.pred_data_N)
344344
assert est_num_purchases.dims == ("chain", "draw", "customer_id")
345345

346346
np.testing.assert_allclose(
@@ -398,33 +398,33 @@ def test_expected_purchases_new_customer(self):
398398
def test_expected_probability_alive(self, test_t):
399399
true_prob_alive = self.lifetimes_model.conditional_probability_alive(
400400
m_periods_in_future=test_t,
401-
frequency=self.sample_data["frequency"],
402-
recency=self.sample_data["recency"],
403-
n_periods=self.sample_data["T"],
401+
frequency=self.pred_data["frequency"],
402+
recency=self.pred_data["recency"],
403+
n_periods=self.pred_data["T"],
404404
)
405405

406406
# test parametrization with default data has different dims
407407
est_prob_alive = self.model.expected_probability_alive(future_t=test_t)
408408
assert est_prob_alive.shape == (self.chains, self.draws, self.N)
409409

410-
sample_data = self.sample_data.assign(future_t=test_t)
411-
est_prob_alive = self.model.expected_probability_alive(sample_data)
410+
pred_data = self.pred_data.assign(future_t=test_t)
411+
est_prob_alive = self.model.expected_probability_alive(pred_data)
412412

413-
assert est_prob_alive.shape == (self.chains, self.draws, self.sample_data_N)
413+
assert est_prob_alive.shape == (self.chains, self.draws, self.pred_data_N)
414414
assert est_prob_alive.dims == ("chain", "draw", "customer_id")
415415
np.testing.assert_allclose(
416416
true_prob_alive,
417417
est_prob_alive.mean(("chain", "draw")),
418418
rtol=0.01,
419419
)
420420

421-
alt_data = self.sample_data.assign(future_t=7.5)
421+
alt_data = self.pred_data.assign(future_t=7.5)
422422
est_prob_alive_t = self.model.expected_probability_alive(alt_data)
423423
assert est_prob_alive.mean() > est_prob_alive_t.mean()
424424

425425
def test_distribution_new_customer(self) -> None:
426426
mock_model = BetaGeoBetaBinomModel(
427-
data=self.data,
427+
data=self.sample_data,
428428
)
429429
mock_model.build_model()
430430
mock_model.idata = az.from_dict(
@@ -444,7 +444,7 @@ def test_distribution_new_customer(self) -> None:
444444
random_seed=rng
445445
)
446446
customer_rec_freq = mock_model.distribution_new_customer_recency_frequency(
447-
self.data, T=self.data["T"], random_seed=rng
447+
self.sample_data, T=self.sample_data["T"], random_seed=rng
448448
)
449449
customer_rec = customer_rec_freq.sel(obs_var="recency")
450450
customer_freq = customer_rec_freq.sel(obs_var="frequency")
@@ -463,7 +463,7 @@ def test_distribution_new_customer(self) -> None:
463463
beta=self.beta_true,
464464
delta=self.delta_true,
465465
gamma=self.gamma_true,
466-
T=self.data["T"],
466+
T=self.sample_data["T"],
467467
),
468468
random_seed=rng,
469469
).T

0 commit comments

Comments
 (0)