
Commit 74697f4

Merge branch 'dev' of github.com:maks-sh/scikit-uplift into dev

2 parents dad9ce4 + 82fcec8
File tree

8 files changed: +151, -66 lines

Readme.rst

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,9 @@
     :align: center
     :alt: scikit-uplift: uplift modeling in scikit-learn style in python
 
+.. |Contribs| image:: https://contrib.rocks/image?repo=maks-sh/scikit-uplift
+    :target: https://github.com/maks-sh/scikit-uplift/graphs/contributors
+    :alt: Contributors
 
 scikit-uplift
 ===============
@@ -199,6 +202,10 @@ We welcome new contributors of all experience levels.
 - Please see our `Contributing Guide <https://www.uplift-modeling.com/en/latest/contributing.html>`_ for more details.
 - By participating in this project, you agree to abide by its `Code of Conduct <https://github.com/maks-sh/scikit-uplift/blob/master/.github/CODE_OF_CONDUCT.md>`__.
 
+Thanks to all our contributors!
+
+|Contribs|
+
 If you have any questions, please contact us at [email protected]
 
 Important links

docs/api/models/ClassTransformationReg.rst

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+********************************************
+`sklift.models <./>`_.ClassTransformationReg
+********************************************
+
+.. autoclass:: sklift.models.models.ClassTransformationReg
+    :members:
docs/api/models/index.rst

Lines changed: 1 addition & 0 deletions
@@ -9,4 +9,5 @@ See :ref:`Models <models>` section of the User Guide for further details.
 
     ./SoloModel
     ./ClassTransformation
+    ./ClassTransformationReg
     ./TwoModels

docs/index.rst

Lines changed: 8 additions & 0 deletions
@@ -4,6 +4,10 @@
 .. |Open In Colab4| image:: https://colab.research.google.com/assets/colab-badge.svg
 .. _Open In Colab4: https://colab.research.google.com/github/maks-sh/scikit-uplift/blob/master/notebooks/uplift_model_selection_tutorial.ipynb
 
+.. |Contribs| image:: https://contrib.rocks/image?repo=maks-sh/scikit-uplift
+    :target: https://github.com/maks-sh/scikit-uplift/graphs/contributors
+    :alt: Contributors
+
 **************
 scikit-uplift
 **************
@@ -76,6 +80,10 @@ Sklift is being actively maintained and welcomes new contributors of all experie
 - Please see our `Contributing Guide <https://www.uplift-modeling.com/en/latest/contributing.html>`_ for more details.
 - By participating in this project, you agree to abide by its `Code of Conduct <https://github.com/maks-sh/scikit-uplift/blob/master/.github/CODE_OF_CONDUCT.md>`__.
 
+Thanks to all our contributors!
+
+|Contribs|
+
 If you have any questions, please contact us at [email protected]
 
 .. toctree::

docs/user_guide/models/index.rst

Lines changed: 1 addition & 0 deletions
@@ -16,4 +16,5 @@ Models
     ./classification
     ./solo_model
     ./revert_label
+    ./transformed_outcome
     ./two_models

docs/user_guide/models/transformed_outcome.rst

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+.. _ClassTransformationReg:
+
+********************
+Transformed Outcome
+********************
+
+Let's redefine the target variable so that it is positive when the target occurred under treatment
+(i.e. the treatment may have had an impact) and negative when the target occurred without any treatment:
+
+.. math::
+    Z = Y \cdot \frac{W - p}{p \cdot (1 - p)}
+
+* :math:`Y` - target vector,
+* :math:`W` - vector of binary communication flags, and
+* :math:`p` - the *propensity score* (the probability that unit :math:`i` is assigned to the treatment group).
+
+It is important to note here that it is possible to estimate :math:`p` as the proportion of objects with :math:`W = 1`
+in the sample, or to use the method from [2], which proposes to estimate :math:`p` as a function of :math:`X` by
+training a classifier on the available data :math:`X = x` with the communication flag vector :math:`W` as
+the target variable.
+
+.. image:: https://habrastorage.org/r/w1560/webt/35/d2/z_/35d2z_-3yhyqhwtw-mt-npws6xk.png
+    :align: center
+    :alt: Transformation of the target in the Transformed Outcome approach
+
+After applying the formula, we get a new target variable :math:`Z_i` and can train a regression model with the
+loss :math:`MSE = \frac{1}{n}\sum_{i=1}^{n} (Z_i - \hat{Z_i})^2`, since it is precisely under MSE that the
+predictions of the model estimate the conditional expectation of the target variable.
+
+It can be proved that the conditional expectation of the transformed target :math:`Z_i` is the desired causal effect:
+
+.. math::
+    E[Z_i| X_i = x] = E[Y_i^1 - Y_i^0 | X_i = x] = \tau(x)
+
+.. hint::
+    In sklift this approach corresponds to the :class:`.ClassTransformationReg` class.
+
+References
+==========
+
+1️⃣ Susan Athey and Guido W. Imbens. Machine learning methods for estimating heterogeneous causal effects. stat, 1050:5, 2015.
+
+2️⃣ P. Richard Hahn, Jared S. Murray, and Carlos Carvalho. Bayesian regression tree models for causal inference: regularization, confounding, and heterogeneous effects. 2019.
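
To make the recipe on this new page concrete, here is a minimal sketch of the transformation in plain Python (the synthetic data, the constant propensity estimate, and the linear model are illustrative assumptions, not part of the docs):

    import numpy as np
    from sklearn.linear_model import LinearRegression

    rng = np.random.default_rng(0)
    n = 10_000
    X = rng.normal(size=(n, 3))               # features
    w = rng.binomial(1, 0.5, size=n)          # binary communication flags W
    y = rng.binomial(1, 0.2 + 0.1 * w)        # binary target Y with a small treatment effect

    p = w.mean()                              # propensity: share of objects with W = 1
    z = y * (w - p) / (p * (1 - p))           # transformed target Z

    model = LinearRegression().fit(X, z)      # minimizing MSE, so predictions estimate E[Z | X]
    uplift = model.predict(X)                 # i.e. the uplift tau(x)

In sklift itself this pipeline is wrapped by the ClassTransformationReg class referenced in the hint above.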

sklift/datasets/datasets.py

Lines changed: 70 additions & 57 deletions
@@ -1,6 +1,6 @@
+import hashlib
 import os
 import shutil
-import hashlib
 
 import pandas as pd
 import requests
@@ -17,7+17,6 @@ def get_data_dir():
 
     Returns:
         string: The path to scikit-uplift data dir.
-
     """
     return os.path.join(os.path.expanduser("~"), "scikit-uplift-data")
 
@@ -27,13 +26,12 @@ def _create_data_dir(path):
 
     Args:
         path (str): The path to scikit-uplift data dir.
-
     """
     if not os.path.isdir(path):
         os.makedirs(path)
 
 
-def _download(url, dest_path, content_length_header_key='Content-Length'):
+def _download(url, dest_path, content_length_header_key='Content-Length', desc=None):
     """Download the file from url and save it locally.
 
     Args:
@@ -48,7 +46,7 @@ def _download(url, dest_path, content_length_header_key='Content-Length'):
 
     with open(dest_path, "wb") as fd:
         total_size_in_bytes = int(req.headers.get(content_length_header_key, 0))
-        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+        progress_bar = tqdm(desc=desc, total=total_size_in_bytes, unit='iB', unit_scale=True)
         for chunk in req.iter_content(chunk_size=2 ** 20):
             progress_bar.update(len(chunk))
             fd.write(chunk)
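
For context, the new desc argument is passed straight to tqdm, which prefixes the progress bar with it, so each download is labeled. A hypothetical call (the destination path is an illustrative assumption; the URL is the real Lenta dataset location used further below):

    _download(url='https://sklift.s3.eu-west-2.amazonaws.com/lenta_dataset.csv.gz',
              dest_path='/tmp/lenta_dataset.csv.gz',
              desc='Lenta dataset')
    # The bar now renders roughly as:
    # Lenta dataset: 100%|##########| 21.4M/21.4M [00:02<00:00, 9.8MiB/s]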
@@ -57,7 +55,7 @@
 
 
 def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing,
-              content_length_header_key='Content-Length'):
+              content_length_header_key='Content-Length', desc=None):
     """Return the path to the dataset.
 
     Args:
@@ -72,7 +70,6 @@
 
     Returns:
         string: The path to the dataset.
-
     """
     if data_home is None:
         if dest_subdir is None:
@@ -91,13 +88,19 @@
 
     if not os.path.isfile(dest_path):
         if download_if_missing:
-            _download(url, dest_path, content_length_header_key)
+            _download(url, dest_path, content_length_header_key, desc)
         else:
             raise IOError("Dataset missing")
     return dest_path
 
-def _get_file_hash(csv_path):
-    with open(csv_path, 'rb') as file_to_check:
+
+def _get_file_hash(path):
+    """Compute the hash value for a file using the md5 algorithm.
+
+    Args:
+        path (str): The path to the file.
+    """
+    with open(path, 'rb') as file_to_check:
         data = file_to_check.read()
     return hashlib.md5(data).hexdigest()
 
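As an aside, _get_file_hash() reads the whole file into memory before hashing. A chunked variant (a sketch, not part of this commit) would keep memory flat on the larger .csv.gz files, reusing the same 2 ** 20 chunk size as _download():

    import hashlib

    def _get_file_hash_streaming(path, chunk_size=2 ** 20):
        """Sketch: md5 of a file, computed chunk by chunk."""
        md5 = hashlib.md5()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                md5.update(chunk)
        return md5.hexdigest()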

@@ -107,7 +110,6 @@ def clear_data_dir(path=None):
 
     Args:
         path (str): The path to scikit-uplift data dir
-
     """
     if path is None:
         path = get_data_dir()
@@ -175,20 +177,21 @@ def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, retu
 
     :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
     """
-
     lenta_metadata = {
+        'desc': 'Lenta dataset',
         'url': 'https://sklift.s3.eu-west-2.amazonaws.com/lenta_dataset.csv.gz',
         'hash': '6ab28ff0989ed8b8647f530e2e86452f'
     }
 
     filename = lenta_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=lenta_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
-                         download_if_missing=download_if_missing)
+                         download_if_missing=download_if_missing,
+                         desc=lenta_metadata['desc'])
 
     if _get_file_hash(csv_path) != lenta_metadata['hash']:
-        raise ValueError(f"The {filename} file is broken,\
-            please clean the directory with the clean_data_dir function, and run the function again")
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     target_col = 'response_att'
     treatment_col = 'group'
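
Putting the pieces together, a hedged usage sketch of the fetcher after this change (the return_X_y_t flag is read off the truncated signature in the hunk header above and should be treated as an assumption):

    from sklift.datasets import fetch_lenta

    # The first call downloads the file with a progress bar labeled
    # "Lenta dataset" and verifies its md5 hash; later calls reuse the cache.
    X, y, treatment = fetch_lenta(return_X_y_t=True)
    print(X.shape, y.mean(), treatment.value_counts())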
@@ -276,23 +279,26 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
 
     :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
     """
-
     x5_metadata = {
+        'desc_train': 'Part 1: X5 train',
+        'desc_clients': 'Part 2: X5 clients',
+        'desc_purchases': 'Part 3: X5 purchases',
         'url_train': 'https://sklift.s3.eu-west-2.amazonaws.com/uplift_train.csv.gz',
         'url_clients': 'https://sklift.s3.eu-west-2.amazonaws.com/clients.csv.gz',
         'url_purchases': 'https://sklift.s3.eu-west-2.amazonaws.com/purchases.csv.gz',
-        'uplift_hash': '2720bbb659daa9e0989b2777b6a42d19',
-        'clients_hash': 'b9cdeb2806b732771de03e819b3354c5',
-        'purchases_hash': '48d2de13428e24e8b61d66fef02957a8'
+        'hash_train': '2720bbb659daa9e0989b2777b6a42d19',
+        'hash_clients': 'b9cdeb2806b732771de03e819b3354c5',
+        'hash_purchases': '48d2de13428e24e8b61d66fef02957a8'
     }
     file_train = x5_metadata['url_train'].split('/')[-1]
     csv_train_path = _get_data(data_home=data_home, url=x5_metadata['url_train'], dest_subdir=dest_subdir,
                                dest_filename=file_train,
-                               download_if_missing=download_if_missing)
+                               download_if_missing=download_if_missing,
+                               desc=x5_metadata['desc_train'])
 
-    if _get_file_hash(csv_train_path) != x5_metadata['uplift_hash']:
-        raise ValueError(f"The {file_train} file is broken,\
-            please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_train_path) != x5_metadata['hash_train']:
+        raise ValueError(f"The {file_train} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     train = pd.read_csv(csv_train_path)
     train_features = list(train.columns)
@@ -307,24 +313,26 @@
     file_clients = x5_metadata['url_clients'].split('/')[-1]
     csv_clients_path = _get_data(data_home=data_home, url=x5_metadata['url_clients'], dest_subdir=dest_subdir,
                                  dest_filename=file_clients,
-                                 download_if_missing=download_if_missing)
+                                 download_if_missing=download_if_missing,
+                                 desc=x5_metadata['desc_clients'])
 
-    if _get_file_hash(csv_clients_path) != x5_metadata['clients_hash']:
-        raise ValueError(f"The {file_clients} file is broken,\
-            please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_clients_path) != x5_metadata['hash_clients']:
+        raise ValueError(f"The {file_clients} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     clients = pd.read_csv(csv_clients_path)
     clients_features = list(clients.columns)
 
     file_purchases = x5_metadata['url_purchases'].split('/')[-1]
     csv_purchases_path = _get_data(data_home=data_home, url=x5_metadata['url_purchases'], dest_subdir=dest_subdir,
                                    dest_filename=file_purchases,
-                                   download_if_missing=download_if_missing)
+                                   download_if_missing=download_if_missing,
+                                   desc=x5_metadata['desc_purchases'])
 
-    if _get_file_hash(csv_clients_path) != x5_metadata['purchases_hash']:
-        raise ValueError(f"The {file_purchases} file is broken,\
-            please clean the directory with the clean_data_dir function, and run the function again")
-
+    if _get_file_hash(csv_purchases_path) != x5_metadata['hash_purchases']:
+        raise ValueError(f"The {file_purchases} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
+
     purchases = pd.read_csv(csv_purchases_path)
     purchases_features = list(purchases.columns)
 
@@ -426,26 +434,28 @@ def fetch_criteo(target_col='visit', treatment_col='treatment', data_home=None,
         raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. "
                          f"Got value target_col={target_col}.")
 
-    criteo_metadata = {
-        'url': '',
-        'criteo_hash': ''
-    }
-
     if percent10:
-        criteo_metadata['url'] = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz'
-        criteo_metadata['criteo_hash'] = 'fe159bcee2cea57548e48eb2d7d5d00c'
+        criteo_metadata = {
+            'desc': 'Criteo dataset (10 percent)',
+            'url': 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz',
+            'hash': 'fe159bcee2cea57548e48eb2d7d5d00c'
+        }
     else:
-        criteo_metadata['url'] = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz"
-        criteo_metadata['criteo_hash'] = 'd2236769ef69e9be52556110102911ec'
+        criteo_metadata = {
+            'desc': 'Criteo dataset',
+            'url': 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz',
+            'hash': 'd2236769ef69e9be52556110102911ec'
+        }
 
     filename = criteo_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=criteo_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
-                         download_if_missing=download_if_missing)
+                         download_if_missing=download_if_missing,
+                         desc=criteo_metadata['desc'])
 
-    if _get_file_hash(csv_path) != criteo_metadata['criteo_hash']:
-        raise ValueError(f"The {filename} file is broken,\
-            please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_path) != criteo_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     dtypes = {
         'exposure': 'Int8',
@@ -544,18 +554,20 @@ def fetch_hillstrom(target_col='visit', data_home=None, dest_subdir=None, downlo
                          f"Got value target_col={target_col}.")
 
     hillstrom_metadata = {
+        'desc': 'Hillstrom dataset',
         'url': 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz',
-        'hillstrom_hash': 'a68a81291f53a14f4e29002629803ba3'
+        'hash': 'a68a81291f53a14f4e29002629803ba3'
    }
 
     filename = hillstrom_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=hillstrom_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
-                         download_if_missing=download_if_missing)
+                         download_if_missing=download_if_missing,
+                         desc=hillstrom_metadata['desc'])
 
-    if _get_file_hash(csv_path) != hillstrom_metadata['hillstrom_hash']:
-        raise ValueError(f"The {filename} file is broken,\
-            please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_path) != hillstrom_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     treatment_col = 'segment'
 
@@ -634,21 +646,22 @@ def fetch_megafon(data_home=None, dest_subdir=None, download_if_missing=True,
     :func:`.fetch_criteo`: Load and return the Criteo Uplift Prediction Dataset (classification).
 
     :func:`.fetch_hillstrom`: Load and return Kevin Hillstrom Dataset MineThatData (classification or regression).
-
     """
     megafon_metadata = {
+        'desc': 'Megafon dataset',
        'url': 'https://sklift.s3.eu-west-2.amazonaws.com/megafon_dataset.csv.gz',
-        'megafon_hash': 'ee8d45a343d4d2cf90bb756c93959ecd'
+        'hash': 'ee8d45a343d4d2cf90bb756c93959ecd'
     }
 
     filename = megafon_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=megafon_metadata['url'], dest_subdir=dest_subdir,
-                         dest_filename=filename,
-                         download_if_missing=download_if_missing)
+                         dest_filename=filename,
+                         download_if_missing=download_if_missing,
+                         desc=megafon_metadata['desc'])
 
-    if _get_file_hash(csv_path) != megafon_metadata['megafon_hash']:
-        raise ValueError(f"The {filename} file is broken,\
-            please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_path) != megafon_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     train = pd.read_csv(csv_path)
 