Skip to content

Commit a140689

Browse files
committed
Merge branch 'dev' of github.com:maks-sh/scikit-uplift into dev
2 parents b1fecae + 29cbb3a commit a140689

File tree

9 files changed

+289
-32
lines changed

9 files changed

+289
-32
lines changed

docs/api/datasets/fetch_criteo.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,17 @@
66

77
.. autofunction:: sklift.datasets.datasets.fetch_criteo
88

9+
About the company
10+
##################
11+
12+
Criteo is an advertising company that provides online display advertisements.
13+
The company was founded and is headquartered in Paris, France. Criteo's product is a form of display advertising,
14+
which displays interactive banner advertisements, generated based on the online browsing preferences and behaviour for each customer.
15+
The solution operates on a pay per click/cost per click (CPC) basis.
16+
17+
.. figure:: https://upload.wikimedia.org/wikipedia/commons/d/d2/Criteo_logo21.svg
18+
19+
Link to the company's website: https://www.criteo.com/
20+
21+
922
.. include:: ../../../sklift/datasets/descr/criteo.rst

docs/api/datasets/fetch_hillstrom.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,12 @@
66

77
.. autofunction:: sklift.datasets.datasets.fetch_hillstrom
88

9+
About the company
10+
##################
11+
12+
The dataset was provided by Kevin Hillstorm.
13+
Kevin is President of MineThatData, a consultancy that helps CEOs understand the complex relationship between Customers, Advertising, Products, Brands, and Channels.
14+
15+
Link to the blog website: https://blog.minethatdata.com/
16+
917
.. include:: ../../../sklift/datasets/descr/hillstrom.rst

docs/api/datasets/fetch_lenta.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,14 @@
66

77
.. autofunction:: sklift.datasets.datasets.fetch_lenta
88

9+
About the company
10+
##################
11+
12+
Lenta (Russian: Лентa) is a Russian super - and hypermarket chain. With 149 locations across the country,
13+
it is one of Russia's largest retail chains in addition to being the country's second largest hypermarket chain.
14+
15+
.. figure:: https://upload.wikimedia.org/wikipedia/commons/7/73/Lenta_logo.svg
16+
17+
Link to the company's website: https://www.lenta.com/
18+
919
.. include:: ../../../sklift/datasets/descr/lenta.rst

docs/api/datasets/fetch_megafon.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,14 @@
66

77
.. autofunction:: sklift.datasets.datasets.fetch_megafon
88

9+
About the company
10+
##################
11+
12+
MegaFon (Russian: МегаФон), previously known as North-West GSM, is the second largest mobile phone operator and the third largest telecom operator in Russia.
13+
It works in the GSM, UMTS and LTE standard. As of June 2012, the company serves 62.1 million subscribers in Russia and 1.6 million in Tajikistan. It is headquartered in Moscow.
14+
15+
.. figure:: https://upload.wikimedia.org/wikipedia/commons/9/9e/MegaFon_logo.svg
16+
17+
Link to the company's website: https://megafon.ru/
18+
919
.. include:: ../../../sklift/datasets/descr/megafon.rst

docs/api/datasets/fetch_x5.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,16 @@
66

77
.. autofunction:: sklift.datasets.datasets.fetch_x5
88

9+
About the company
10+
##################
11+
12+
X5 Group is a leading Russian food retailer.
13+
The Company operates several retail formats: proximity stores under the Pyaterochka brand,
14+
supermarkets under the Perekrestok brand and hypermarkets under the Karusel brand, as well as the Perekrestok.ru online market,
15+
the 5Post parcel and Dostavka.Pyaterochka and Perekrestok. Bystro food delivery services.
16+
17+
.. figure:: https://upload.wikimedia.org/wikipedia/en/8/83/X5_Retail_Group_logo_2015.png
18+
19+
Link to the company's website: https://www.x5.ru/
20+
921
.. include:: ../../../sklift/datasets/descr/x5.rst

sklift/tests/test_datasets.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,23 @@ def test_fetch_lenta(lenta_dataset):
3434
assert data.target.shape == lenta_dataset['target.shape']
3535
assert data.treatment.shape == lenta_dataset['treatment.shape']
3636

37-
38-
# @pytest.fixture
39-
# def x5_dataset() -> dict:
40-
# data = {'keys': ['data', 'target', 'treatment', 'DESCR', 'feature_names', 'target_name', 'treatment_name'],
37+
#@pytest.fixture
38+
#def x5_dataset() -> dict:
39+
# data = {'keys': ['data', 'target', 'treatment', 'DESCR', 'feature_names', 'target_name', 'treatment_name'],
4140
# 'data.keys': ['clients', 'train', 'purchases'], 'clients.shape': (400162, 5),
42-
# 'train.shape': (200039, 1), 'target.shape': (200039,), 'treatment.shape': (200039,)}
43-
# return data
44-
#
41+
# 'train.shape': (200039, 1), 'target.shape': (200039,), 'treatment.shape': (200039,)}
42+
# return data
43+
4544
#
46-
# def test_fetch_x5(x5_dataset):
47-
# data = fetch_x5()
48-
# assert isinstance(data, sklearn.utils.Bunch)
49-
# assert set(data.keys()) == set(x5_dataset['keys'])
50-
# assert set(data.data.keys()) == set(x5_dataset['data.keys'])
51-
# assert data.data.clients.shape == x5_dataset['clients.shape']
52-
# assert data.data.train.shape == x5_dataset['train.shape']
53-
# assert data.target.shape == x5_dataset['target.shape']
54-
# assert data.treatment.shape == x5_dataset['treatment.shape']
45+
#def test_fetch_x5(x5_dataset):
46+
# data = fetch_x5()
47+
# assert isinstance(data, sklearn.utils.Bunch)
48+
# assert set(data.keys()) == set(x5_dataset['keys'])
49+
# assert set(data.data.keys()) == set(x5_dataset['data.keys'])
50+
# assert data.data.clients.shape == x5_dataset['clients.shape']
51+
# assert data.data.train.shape == x5_dataset['train.shape']
52+
# assert data.target.shape == x5_dataset['target.shape']
53+
# assert data.treatment.shape == x5_dataset['treatment.shape']
5554

5655

5756
@pytest.fixture
@@ -85,6 +84,14 @@ def test_fetch_criteo10(
8584
assert data.target.shape == target_shape
8685
assert data.treatment.shape == treatment_shape
8786

87+
@pytest.mark.parametrize(
88+
'target_col, treatment_col',
89+
[('visit','new_trmnt'), ('new_target','treatment')]
90+
)
91+
def test_fetch_criteo_errors(target_col, treatment_col):
92+
with pytest.raises(ValueError):
93+
fetch_criteo(target_col=target_col, treatment_col=treatment_col)
94+
8895

8996
@pytest.fixture
9097
def hillstrom_dataset() -> dict:
@@ -111,6 +118,10 @@ def test_fetch_hillstrom(
111118
assert data.target.shape == target_shape
112119
assert data.treatment.shape == hillstrom_dataset['treatment.shape']
113120

121+
def test_fetch_hillstrom_error():
122+
with pytest.raises(ValueError):
123+
fetch_hillstrom(target_col='new_target')
124+
114125

115126
@pytest.fixture
116127
def megafon_dataset() -> dict:

sklift/tests/test_metrics.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
from sklearn.utils._testing import assert_array_almost_equal
99

10+
from ..metrics import make_uplift_scorer
1011
from ..metrics import uplift_curve, uplift_auc_score, perfect_uplift_curve
1112
from ..metrics import qini_curve, qini_auc_score, perfect_qini_curve
1213
from ..metrics import (uplift_at_k, response_rate_by_percentile,
13-
weighted_average_uplift, uplift_by_percentile, treatment_balance_curve)
14+
weighted_average_uplift, uplift_by_percentile, treatment_balance_curve, average_squared_deviation)
1415

1516

1617
def make_predictions(binary):
@@ -221,6 +222,12 @@ def test_perfect_qini_curve_hard():
221222

222223
assert_array_almost_equal(x_actual, np.array([0., 0., 3.]))
223224
assert_array_almost_equal(y_actual, np.array([0.0, 0.0, 0.0]))
225+
226+
def test_perfect_qini_curve_error():
227+
y_true, uplift, treatment = make_predictions(binary=True)
228+
with pytest.raises(TypeError):
229+
perfect_qini_curve(y_true, treatment, negative_effect=5)
230+
224231

225232

226233
def test_qini_auc_score():
@@ -255,11 +262,33 @@ def test_qini_auc_score():
255262
treatment = [1, 0, 1]
256263
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 0.75)
257264

265+
def test_qini_auc_score_error():
266+
y_true = [1, 0]
267+
uplift = [0.1, 0.3]
268+
treatment = [0, 1]
269+
with pytest.raises(TypeError):
270+
qini_auc_score(y_true, uplift, treatment, negative_effect=5)
271+
258272

259273
def test_uplift_at_k():
260274
y_true, uplift, treatment = make_predictions(binary=True)
261275

262276
assert_array_almost_equal(uplift_at_k(y_true, uplift, treatment, strategy='by_group', k=1), np.array([0.]))
277+
#assert_array_almost_equal(uplift_at_k(y_true, uplift, treatment, strategy='overall', k=2), np.array([0.]))
278+
279+
@pytest.mark.parametrize(
280+
"strategy, k",
281+
[
282+
('new_strategy', 1),
283+
('by_group', -0.5),
284+
('by_group', '1'),
285+
('by_group', 2)
286+
]
287+
)
288+
def test_uplift_at_k_errors(strategy, k):
289+
y_true, uplift, treatment = make_predictions(binary=True)
290+
with pytest.raises(ValueError):
291+
uplift_at_k(y_true, uplift, treatment, strategy, k)
263292

264293

265294
@pytest.mark.parametrize(
@@ -277,6 +306,19 @@ def test_response_rate_by_percentile(strategy, group, response_rate):
277306
assert_array_almost_equal(response_rate_by_percentile(y_true, uplift, treatment, group, strategy, bins=1),
278307
response_rate)
279308

309+
@pytest.mark.parametrize(
310+
"strategy, group, bins",
311+
[
312+
('new_strategy', 'control', 1),
313+
('by_group', 'ctrl', 1),
314+
('by_group', 'control', 0.5),
315+
('by_group', 'control', 9999)
316+
]
317+
)
318+
def test_response_rate_by_percentile_errors(strategy, group, bins):
319+
y_true, uplift, treatment = make_predictions(binary=True)
320+
with pytest.raises(ValueError):
321+
response_rate_by_percentile(y_true, uplift, treatment, group=group, strategy=strategy, bins=bins)
280322

281323
@pytest.mark.parametrize(
282324
"strategy, weighted_average",
@@ -289,7 +331,21 @@ def test_weighted_average_uplift(strategy, weighted_average):
289331
y_true, uplift, treatment = make_predictions(binary=True)
290332

291333
assert_array_almost_equal(weighted_average_uplift(y_true, uplift, treatment, strategy, bins=1), weighted_average)
334+
292335

336+
@pytest.mark.parametrize(
337+
"strategy, bins",
338+
[
339+
('new_strategy', 1),
340+
('by_group', 0.5),
341+
('by_group', 9999)
342+
]
343+
)
344+
def test_weighted_average_uplift_errors(strategy, bins):
345+
y_true, uplift, treatment = make_predictions(binary=True)
346+
with pytest.raises(ValueError):
347+
weighted_average_uplift(y_true, uplift, treatment, strategy=strategy, bins=bins)
348+
293349

294350
@pytest.mark.parametrize(
295351
"strategy, bins, std, total, string_percentiles, data",
@@ -307,11 +363,68 @@ def test_uplift_by_percentile(strategy, bins, std, total, string_percentiles, da
307363

308364
assert_array_almost_equal(
309365
uplift_by_percentile(y_true, uplift, treatment, strategy, bins, std, total, string_percentiles), data)
366+
367+
@pytest.mark.parametrize(
368+
"strategy, bins, std, total, string_percentiles",
369+
[
370+
('new_strategy', 1, True, True, True),
371+
('by_group', 0.5, True, True, True),
372+
('by_group', 9999, True, True, True),
373+
('by_group', 1, 2, True, True),
374+
('by_group', 1, True, True, 2),
375+
('by_group', 1, True, 2, True)
376+
]
377+
)
378+
def test_uplift_by_percentile_errors(strategy, bins, std, total, string_percentiles):
379+
y_true, uplift, treatment = make_predictions(binary=True)
380+
with pytest.raises(ValueError):
381+
uplift_by_percentile(y_true, uplift, treatment, strategy, bins, std, total, string_percentiles)
310382

311383

312384
def test_treatment_balance_curve():
313385
y_true, uplift, treatment = make_predictions(binary=True)
314386

315387
idx, balance = treatment_balance_curve(uplift, treatment, winsize=2)
316388
assert_array_almost_equal(idx, np.array([1., 100.]))
317-
assert_array_almost_equal(balance, np.array([1., 0.5]))
389+
assert_array_almost_equal(balance, np.array([1., 0.5]))
390+
391+
@pytest.mark.parametrize(
392+
"strategy",
393+
[
394+
('overall'),
395+
('by_group')
396+
]
397+
)
398+
def test_average_squared_deviation(strategy):
399+
y_true, uplift, treatment = make_predictions(binary=True)
400+
assert (average_squared_deviation(y_true, uplift, treatment, y_true, uplift, treatment, strategy, bins=1) == 0)
401+
402+
@pytest.mark.parametrize(
403+
"strategy, bins",
404+
[
405+
('new_strategy', 1),
406+
('by_group', 0.5),
407+
('by_group', 9999)
408+
]
409+
)
410+
def test_average_squared_deviation_errors(strategy, bins):
411+
y_true, uplift, treatment = make_predictions(binary=True)
412+
with pytest.raises(ValueError):
413+
average_squared_deviation(y_true, uplift, treatment, y_true, uplift, treatment, strategy=strategy, bins=bins)
414+
415+
def test_metric_name_error():
416+
with pytest.raises(ValueError):
417+
make_uplift_scorer('new_scorer', [0, 1])
418+
419+
def test_make_scorer_error():
420+
with pytest.raises(TypeError):
421+
make_uplift_scorer('qini_auc_score', [])
422+
423+
424+
425+
426+
427+
428+
429+
430+

sklift/tests/test_models.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import numpy as np
23
from sklearn.linear_model import LogisticRegression, LinearRegression
34
from sklearn.pipeline import Pipeline
45
from sklearn.preprocessing import StandardScaler
@@ -43,3 +44,50 @@ def test_shape_regression(model, random_xy_dataset_regr):
4344
assert model.fit(X, y, treat).predict(X).shape[0] == y.shape[0]
4445
pipe = Pipeline(steps=[("scaler", StandardScaler()), ("clf", model)])
4546
assert pipe.fit(X, y, clf__treatment=treat).predict(X).shape[0] == y.shape[0]
47+
48+
@pytest.mark.parametrize(
49+
"model",
50+
[
51+
SoloModel(LogisticRegression(), method='dummy'),
52+
SoloModel(LogisticRegression(), method='treatment_interaction'),
53+
]
54+
)
55+
def test_solomodel_fit_error(model):
56+
X, y, treatment = [[1., 0., 0.],[1., 0., 0.],[1., 0., 0.]], [1., 2., 3.], [0., 1., 0.]
57+
with pytest.raises(TypeError):
58+
model.fit(X, y, treatment)
59+
60+
@pytest.mark.parametrize(
61+
"model",
62+
[
63+
SoloModel(LogisticRegression(), method='dummy'),
64+
SoloModel(LogisticRegression(), method='treatment_interaction'),
65+
]
66+
)
67+
def test_solomodel_pred_error(model):
68+
X_train, y_train, treat_train = (np.array([[5.1, 3.5, 1.4, 0.2], [4.9, 3.0, 1.4, 0.2], [4.7, 3.2, 1.3, 0.2]]),
69+
np.array([0.0, 0.0, 1.0]), np.array([0.0, 1.0, 1.0]))
70+
model.fit(X_train, y_train, treat_train)
71+
with pytest.raises(TypeError):
72+
model.predict(1)
73+
74+
@pytest.mark.parametrize("method", ['method'])
75+
def test_solomodel_method_error(method):
76+
with pytest.raises(ValueError):
77+
SoloModel(LogisticRegression(), method=method)
78+
79+
def test_classtransformation_fit_error():
80+
X, y, treatment = [[1., 0., 0.],[1., 0., 0.],[1., 0., 0.]], [1., 2., 3.], [0., 1., 0.]
81+
with pytest.raises(ValueError):
82+
ClassTransformation(LogisticRegression()).fit(X, y, treatment)
83+
84+
@pytest.mark.parametrize("method", ['method'])
85+
def test_twomodels_method_error(method):
86+
with pytest.raises(ValueError):
87+
TwoModels(LinearRegression(), LinearRegression(), method=method)
88+
89+
def test_same_estimator_error():
90+
est = LinearRegression()
91+
with pytest.raises(ValueError):
92+
TwoModels(est, est)
93+

0 commit comments

Comments
 (0)