
Commit 81974ac

Merge branch 'dev' of https://github.com/maks-sh/scikit-uplift into dev
2 parents 76a8756 + 7e68075

File tree

11 files changed: 296 additions, 188 deletions

.github/workflows/PyPi_upload.yml

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+name: Upload to PyPi
+
+on:
+  release:
+    types: [published]
+
+jobs:
+  deploy:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.x'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install setuptools wheel twine
+    - name: Build and publish
+      env:
+        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+      run: |
+        python setup.py sdist bdist_wheel
+        twine upload dist/*
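The job above publishes to PyPI on every published GitHub release. As a side note (not part of this commit), the build step can be mirrored locally to sanity-check the artifacts with `twine check` before tagging a release; this sketch assumes setuptools, wheel and twine are installed:

    # Hedged local sketch, not from the commit: build the same artifacts and
    # validate their metadata without uploading anything.
    import glob
    import subprocess

    subprocess.run(["python", "setup.py", "sdist", "bdist_wheel"], check=True)
    subprocess.run(["twine", "check", *glob.glob("dist/*")], check=True)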

.github/workflows/ci-test.yml

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+name: Python package
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+
+jobs:
+  build:
+
+    runs-on: ${{ matrix.operating-system }}
+    strategy:
+      matrix:
+        operating-system: [ubuntu-latest, windows-latest, macos-latest]
+        python-version: [3.6, 3.7, 3.8, 3.9]
+      fail-fast: false
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies and lints
+      run: pip install pytest .[tests]
+    - name: Run PyTest
+      run: pytest
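The matrix job runs `pip install pytest .[tests]` and then `pytest` on three operating systems and four Python versions. For illustration only (this file is not part of the commit), the kind of smoke test pytest would collect could be as small as:

    # Illustrative pytest smoke test; assumes nothing beyond the sklift package being importable.
    import importlib

    def test_import_sklift():
        # Should pass on every OS / Python combination in the matrix above.
        assert importlib.import_module("sklift") is not None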

docs/conf.py

Lines changed: 4 additions & 1 deletion

@@ -51,9 +51,12 @@ def get_version():
     "sphinx.ext.mathjax",
     "sphinx.ext.napoleon",
     "recommonmark",
-    "sphinx.ext.intersphinx"
+    "sphinx.ext.intersphinx",
+    "sphinxcontrib.bibtex"
 ]

+bibtex_bibfiles = ['refs.bib']
+
 master_doc = 'index'

 # Add any paths that contain templates here, relative to this directory.
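Reassembled from the hunk above for readability, the tail of the `extensions` list and the new setting read as follows after the change; note that newer versions of sphinxcontrib-bibtex require `bibtex_bibfiles` to point at the .bib file(s) used by the docs (the `refs.bib` path comes from the diff, earlier list entries are omitted here):

    # docs/conf.py after this hunk (tail of the extensions list only)
    extensions = [
        # ... earlier entries unchanged ...
        "sphinx.ext.mathjax",
        "sphinx.ext.napoleon",
        "recommonmark",
        "sphinx.ext.intersphinx",
        "sphinxcontrib.bibtex"
    ]

    bibtex_bibfiles = ['refs.bib']  # required by recent sphinxcontrib-bibtex releases

    master_doc = 'index'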

docs/requirements.txt

Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
 sphinx-autobuild
 sphinx_rtd_theme
-recommonmark
+recommonmark
+sphinxcontrib-bibtex

sklift/datasets/datasets.py

Lines changed: 123 additions & 73 deletions

@@ -6,36 +6,38 @@


 def get_data_dir():
-    """This function returns a directory, which stores the datasets.
+    """Return the path of the scikit-uplift data dir.
+
+    This folder is used by some large dataset loaders to avoid downloading the data several times.
+
+    By default the data dir is set to a folder named ‘scikit_learn_data’ in the user home folder.

     Returns:
-        Full path to a directory, which stores the datasets.
+        string: The path to scikit-uplift data dir.

     """
     return os.path.join(os.path.expanduser("~"), "scikit-uplift-data")


 def _create_data_dir(path):
-    """This function creates a directory, which stores the datasets.
+    """Creates a directory, which stores the datasets.

     Args:
-        path (str): The path to the folder where datasets are stored.
+        path (str): The path to scikit-uplift data dir.

     """
     if not os.path.isdir(path):
         os.makedirs(path)


 def _download(url, dest_path):
-    '''Download the file from url and save it localy
-
+    """Download the file from url and save it locally.
+
     Args:
         url: URL address, must be a string.
         dest_path: Destination of the file.

-    Returns:
-        TypeError if URL is not a string.
-    '''
+    """
     if isinstance(url, str):
         req = requests.get(url, stream=True)
         req.raise_for_status()
@@ -51,14 +53,16 @@ def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
     """Return the path to the dataset.

     Args:
-        data_home (str, unicode): The path to the folder where datasets are stored.
+        data_home (str, unicode): The path to scikit-uplift data dir.
         url (str or unicode): The URL to the dataset.
         dest_subdir (str or unicode): The name of the folder in which the dataset is stored.
         dest_filename (str): The name of the dataset.
-        download_if_missing (bool): Flag if dataset is missing.
+        download_if_missing (bool): If False, raise a IOError if the data is not locally available instead of
+            trying to download the data from the source site.

     Returns:
-        The path to the dataset.
+        string: The path to the dataset.
+
     """
     if data_home is None:
         if dest_subdir is None:
@@ -84,43 +88,59 @@ def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):


 def clear_data_dir(path=None):
-    """This function deletes the file.
+    """Delete all the content of the data home cache.

     Args:
-        path (str): File path. By default, this is the default path for datasets.
-    """
+        path (str): The path to scikit-uplift data dir
+
+    """
     if path is None:
         path = get_data_dir()
     if os.path.isdir(path):
         shutil.rmtree(path, ignore_errors=True)


-def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, return_X_y_t=False, as_frame=False):
-    """Fetch the Lenta dataset.

-    Args:
-        data_home (str, unicode): The path to the folder where datasets are stored.
-        dest_subdir (str, unicode): The name of the folder in which the dataset is stored.
-        download_if_missing (bool): Download the data if not present. Raises an IOError if False and data is missing.
-        return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.
-            See below for more information about the data and target object.
-        as_frame (bool):
-
-    Returns:
-        * dataset ('~sklearn.utils.Bunch'): Dictionary-like object, with the following attributes.
-            * data (DataFrame object): Dataset without target and treatment.
-            * target (Series object): Column target by values.
-            * treatment (Series object): Column treatment by values.
-            * DESCR (str): Description of the Lenta dataset.
-
-        * (data,target,treatment): tuple if 'return_X_y_t' is True.
+def fetch_lenta(return_X_y_t=False, data_home=None, dest_subdir=None, download_if_missing=True):
+    """Load and return the Lenta dataset (classification).
+
+    An uplift modeling dataset containing data about Lenta's customers grociery shopping and related marketing campaigns.
+
+    Major columns:
+
+    - ``group`` (str): treatment/control group flag
+    - ``response_att`` (binary): target
+    - ``gender`` (str): customer gender
+    - ``age`` (float): customer age
+    - ``main_format`` (int): store type (1 - grociery store, 0 - superstore)
+
+    Args:
+        return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.
+            See below for more information about the data and target object.
+        data_home (str, unicode): The path to the folder where datasets are stored.
+        dest_subdir (str, unicode): The name of the folder in which the dataset is stored.
+        download_if_missing (bool): Download the data if not present. Raises an IOError if False and data is missing.
+
+    Returns:
+        Bunch or tuple: dataset.
+
+        By default dictionary-like object, with the following attributes:
+
+            * ``data`` (DataFrame object): Dataset without target and treatment.
+            * ``target`` (Series object): Column target by values.
+            * ``treatment`` (Series object): Column treatment by values.
+            * ``DESCR`` (str): Description of the Lenta dataset.
+
+        tuple (data, target, treatment) if `return_X_y` is True
     """

-    url = 'https://winterschool123.s3.eu-north-1.amazonaws.com/lentadataset.csv.gz'
-    filename = 'lentadataset.csv.gz'
-    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
-                         dest_filename=filename,
-                         download_if_missing=download_if_missing)
+    url='https:/winterschool123.s3.eu-north-1.amazonaws.com/lentadataset.csv.gz'
+    filename='lentadataset.csv.gz'
+
+    csv_path=_get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
+                       dest_filename=filename,
+                       download_if_missing=download_if_missing)
+
     data = pd.read_csv(csv_path)
     if as_frame:
         target=data['response_att']
@@ -145,27 +165,33 @@ def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, retu
                  feature_names=feature_names, target_name='response_att', treatment_name='group')


-def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True, as_frame=False):
-    """Fetch the X5 dataset.
+def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
+    """Load the X5 dataset.
+
+    The dataset contains raw retail customer purchaces, raw information about products and general info about customers.
+
+    Major columns:
+
+    - ``treatment_flg`` (binary): treatment/control group flag
+    - ``target`` (binary): target
+    - ``customer_id`` (str): customer id aka primary key for joining

     Args:
-        data_home (string): Specify a download and cache folder for the datasets.
-        dest_subdir (string, unicode): The name of the folder in which the dataset is stored.
-        download_if_missing (bool, default=True): If False, raise an IOError if the data is not locally available
-            instead of trying to download the data from the source site.
-        as_frame (bool, default=False):
+        data_home (str, unicode): The path to the folder where datasets are stored.
+        dest_subdir (str, unicode): The name of the folder in which the dataset is stored.
+        download_if_missing (bool): Download the data if not present. Raises an IOError if False and data is missing.

     Returns:
-        '~sklearn.utils.Bunch': dataset
-            Dictionary-like object, with the following attributes.
-                data ('~sklearn.utils.Bunch'): Dataset without target and treatment.
-                target (Series object): Column target by values
-                treatment (Series object): Column treatment by values
-                DESCR (str): Description of the X5 dataset.
-                train (DataFrame object): Dataset with target and treatment.
-                data_names ('~sklearn.utils.Bunch'): Names of features.
-                treatment_name (string): The name of the treatment column.
+        Bunch: dataset Dictionary-like object, with the following attributes.
+
+            * data ('~sklearn.utils.Bunch'): Dataset without target and treatment.
+            * target (Series object): Column target by values
+            * treatment (Series object): Column treatment by values
+            * DESCR (str): Description of the X5 dataset.
+            * train (DataFrame object): Dataset with target and treatment.
+
     """
+
     url_clients = 'https://timds.s3.eu-central-1.amazonaws.com/clients.csv.gz'
     file_clients = 'clients.csv.gz'
     csv_clients_path = _get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir,
@@ -213,8 +239,19 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True, as_fram

 def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, percent10=True,
                  treatment_feature='treatment', target_column='visit', return_X_y_t=False, as_frame=False):
-    """Load data from the Criteo dataset
-
+    """Load data from the Criteo dataset.
+
+    This dataset is constructed by assembling data resulting from several incrementality tests, a particular randomized
+    trial procedure where a random part of the population is prevented from being targeted by advertising.
+
+    Major columns:
+
+    * ``treatment`` (binary): treatment
+    * ``exposure`` (binary): treatment
+    * ``visit`` (binary): target
+    * ``conversion`` (binary): target
+    * ``f0, ... , f11`` (float): feature values
+
     Args:
         data_home (string): Specify a download and cache folder for the datasets.
         dest_subdir (string, unicode): The name of the folder in which the dataset is stored.
@@ -227,7 +264,8 @@ def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, per
             will be target
         return_X_y_t (bool, default=False): If True, returns (data, target, treatment) instead of a Bunch object.
             See below for more information about the data and target object.
-        as_frame (bool, default=False):
+        as_frame (bool, default=False): If True, return as pandas.Series
+
     Returns:
         ''~sklearn.utils.Bunch'': dataset
             Dictionary-like object, with the following attributes.
@@ -300,29 +338,41 @@ def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, per
 def fetch_hillstrom(data_home=None, dest_subdir=None, download_if_missing=True, target_column='visit',
                     return_X_y_t=False, as_frame=False):
     """Load the hillstrom dataset.
+
+    This dataset contains 64,000 customers who last purchased within twelve months. The customers were involved in an e-mail test.
+
+    Major columns:
+
+    * ``Visit`` (binary): target. 1/0 indicator, 1 = Customer visited website in the following two weeks.
+    * ``Conversion`` (binary): target. 1/0 indicator, 1 = Customer purchased merchandise in the following two weeks.
+    * ``Spend`` (float): target. Actual dollars spent in the following two weeks.
+    * ``Segment`` (str): treatment. The e-mail campaign the customer received

-    Args:
-        data_home : str, default=None
-            Specify another download and cache folder for the datasets.
-        dest_subdir : str, default=None
-        download_if_missing : bool, default=True
-            If False, raise a IOError if the data is not locally available
-            instead of trying to download the data from the source site.
+    Args:
+        target : str, desfault=visit.
+            Can also be conversion, and spend
+        data_home : str, default=None
+            Specify another download and cache folder for the datasets.
+        dest_subdir : str, default=None
+        download_if_missing : bool, default=True
+            If False, raise a IOError if the data is not locally available
+            instead of trying to download the data from the source site.
         target_column (string, 'visit' or 'conversion' or 'spend', default='visit'): Selects which column from dataset
-            will be target
+            will be target
         return_X_y_t (bool):
         as_frame (bool):

-    Returns:
-        Dictionary-like object, with the following attributes.
-            data : {ndarray, dataframe} of shape (64000, 12)
+    Returns:
+        Dictionary-like object, with the following attributes.
+            data : {ndarray, dataframe} of shape (64000, 12)
                 The data matrix to learn.
-            target : {ndarray, series} of shape (64000,)
+            target : {ndarray, series} of shape (64000,)
                 The regression target for each sample.
-            treatment : {ndarray, series} of shape (64000,)
-            feature_names (list): The names of the future columns
-            target_name (string): The name of the target column.
-            treatment_name (string): The name of the treatment column
+            treatment : {ndarray, series} of shape (64000,)
+            feature_names (list): The names of the future columns
+            target_name (string): The name of the target column.
+            treatment_name (string): The name of the treatment column
+
     """

     url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
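Not part of the diff: a hedged usage sketch of the documented interface, based only on the docstrings above and assuming the fetchers are re-exported from `sklift.datasets` (otherwise they would be imported from `sklift.datasets.datasets`):

    # Hedged usage sketch (not from the commit).
    from sklift.datasets import fetch_hillstrom, fetch_lenta

    bunch = fetch_lenta()                               # Bunch with .data, .target, .treatment, .DESCR
    X, y, treat = fetch_lenta(return_X_y_t=True)        # tuple form described in the Returns section

    hillstrom = fetch_hillstrom(target_column='spend')  # 'visit' (default), 'conversion' or 'spend'
    print(bunch.DESCR)

Note that in this intermediate revision `fetch_lenta` appears to still reference `as_frame` in its body even though the parameter was dropped from its signature, so the sketch reflects the documented interface rather than this exact state of the code.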
