|
2 | 2 | # Licensed under the MIT License. |
3 | 3 |
|
4 | 4 | import json |
5 | | -import os |
6 | | -import shutil |
7 | | -import zipfile |
8 | | -from urllib.request import urlretrieve |
9 | 5 |
|
10 | 6 | import numpy as np |
11 | 7 | import pandas as pd |
12 | 8 | import pytest |
| 9 | +from dice_ml.utils import helpers |
13 | 10 | # Defines common utilities for responsibleai tests |
14 | 11 | from sklearn.model_selection import train_test_split |
15 | 12 |
|
@@ -44,191 +41,7 @@ def __init__(self): |
44 | 41 | pass |
45 | 42 |
|
46 | 43 | def fetch(self): |
47 | | - """ |
48 | | - Loads adult income dataset from |
49 | | - https://archive.ics.uci.edu/ml/datasets/Adult and prepares the |
50 | | - data for data analysis based on https://rpubs.com/H_Zhu/235617 |
51 | | -
|
52 | | - :return adult_data: returns preprocessed adult income dataset. |
53 | | - """ |
54 | | - # TODO: Revert to using load_adult_income_dataset once dice-ml has a |
55 | | - # new release with the fix. |
56 | | - # Download the adult dataset from |
57 | | - # https://archive.ics.uci.edu/static/public/2/adult.zip as a zip folder |
58 | | - outdirname = "adult" |
59 | | - zipfilename = outdirname + ".zip" |
60 | | - urlretrieve( |
61 | | - "https://archive.ics.uci.edu/static/public/2/adult.zip", |
62 | | - zipfilename, |
63 | | - ) |
64 | | - with zipfile.ZipFile(zipfilename, "r") as unzip: |
65 | | - unzip.extractall(outdirname) |
66 | | - |
67 | | - raw_data = np.genfromtxt( |
68 | | - outdirname + "/adult.data", |
69 | | - delimiter=", ", |
70 | | - dtype=str, |
71 | | - invalid_raise=False, |
72 | | - ) |
73 | | - |
74 | | - # column names from "https://archive.ics.uci.edu/ml/datasets/Adult" |
75 | | - column_names = [ |
76 | | - "age", |
77 | | - "workclass", |
78 | | - "fnlwgt", |
79 | | - "education", |
80 | | - "educational-num", |
81 | | - "marital-status", |
82 | | - "occupation", |
83 | | - "relationship", |
84 | | - "race", |
85 | | - "gender", |
86 | | - "capital-gain", |
87 | | - "capital-loss", |
88 | | - "hours-per-week", |
89 | | - "native-country", |
90 | | - "income", |
91 | | - ] |
92 | | - |
93 | | - adult_data = pd.DataFrame(raw_data, columns=column_names) |
94 | | - |
95 | | - # For more details on how the below transformations are made, |
96 | | - # please refer to https://rpubs.com/H_Zhu/235617 |
97 | | - adult_data = adult_data.astype( |
98 | | - { |
99 | | - "age": np.int64, |
100 | | - "educational-num": np.int64, |
101 | | - "hours-per-week": np.int64, |
102 | | - } |
103 | | - ) |
104 | | - |
105 | | - adult_data = adult_data.replace( |
106 | | - { |
107 | | - "workclass": { |
108 | | - "Without-pay": "Other/Unknown", |
109 | | - "Never-worked": "Other/Unknown", |
110 | | - } |
111 | | - } |
112 | | - ) |
113 | | - adult_data = adult_data.replace( |
114 | | - { |
115 | | - "workclass": { |
116 | | - "Federal-gov": "Government", |
117 | | - "State-gov": "Government", |
118 | | - "Local-gov": "Government", |
119 | | - } |
120 | | - } |
121 | | - ) |
122 | | - adult_data = adult_data.replace( |
123 | | - { |
124 | | - "workclass": { |
125 | | - "Self-emp-not-inc": "Self-Employed", |
126 | | - "Self-emp-inc": "Self-Employed", |
127 | | - } |
128 | | - } |
129 | | - ) |
130 | | - adult_data = adult_data.replace( |
131 | | - { |
132 | | - "workclass": { |
133 | | - "Never-worked": "Self-Employed", |
134 | | - "Without-pay": "Self-Employed", |
135 | | - } |
136 | | - } |
137 | | - ) |
138 | | - adult_data = adult_data.replace({"workclass": {"?": "Other/Unknown"}}) |
139 | | - |
140 | | - adult_data = adult_data.replace( |
141 | | - { |
142 | | - "occupation": { |
143 | | - "Adm-clerical": "White-Collar", |
144 | | - "Craft-repair": "Blue-Collar", |
145 | | - "Exec-managerial": "White-Collar", |
146 | | - "Farming-fishing": "Blue-Collar", |
147 | | - "Handlers-cleaners": "Blue-Collar", |
148 | | - "Machine-op-inspct": "Blue-Collar", |
149 | | - "Other-service": "Service", |
150 | | - "Priv-house-serv": "Service", |
151 | | - "Prof-specialty": "Professional", |
152 | | - "Protective-serv": "Service", |
153 | | - "Tech-support": "Service", |
154 | | - "Transport-moving": "Blue-Collar", |
155 | | - "Unknown": "Other/Unknown", |
156 | | - "Armed-Forces": "Other/Unknown", |
157 | | - "?": "Other/Unknown", |
158 | | - } |
159 | | - } |
160 | | - ) |
161 | | - |
162 | | - adult_data = adult_data.replace( |
163 | | - { |
164 | | - "marital-status": { |
165 | | - "Married-civ-spouse": "Married", |
166 | | - "Married-AF-spouse": "Married", |
167 | | - "Married-spouse-absent": "Married", |
168 | | - "Never-married": "Single", |
169 | | - } |
170 | | - } |
171 | | - ) |
172 | | - |
173 | | - adult_data = adult_data.replace( |
174 | | - { |
175 | | - "race": { |
176 | | - "Black": "Other", |
177 | | - "Asian-Pac-Islander": "Other", |
178 | | - "Amer-Indian-Eskimo": "Other", |
179 | | - } |
180 | | - } |
181 | | - ) |
182 | | - |
183 | | - adult_data = adult_data[ |
184 | | - [ |
185 | | - "age", |
186 | | - "workclass", |
187 | | - "education", |
188 | | - "marital-status", |
189 | | - "occupation", |
190 | | - "race", |
191 | | - "gender", |
192 | | - "hours-per-week", |
193 | | - "income", |
194 | | - ] |
195 | | - ] |
196 | | - |
197 | | - adult_data = adult_data.replace({"income": {"<=50K": 0, ">50K": 1}}) |
198 | | - |
199 | | - adult_data = adult_data.replace( |
200 | | - { |
201 | | - "education": { |
202 | | - "Assoc-voc": "Assoc", |
203 | | - "Assoc-acdm": "Assoc", |
204 | | - "11th": "School", |
205 | | - "10th": "School", |
206 | | - "7th-8th": "School", |
207 | | - "9th": "School", |
208 | | - "12th": "School", |
209 | | - "5th-6th": "School", |
210 | | - "1st-4th": "School", |
211 | | - "Preschool": "School", |
212 | | - } |
213 | | - } |
214 | | - ) |
215 | | - |
216 | | - adult_data = adult_data.rename( |
217 | | - columns={ |
218 | | - "marital-status": "marital_status", |
219 | | - "hours-per-week": "hours_per_week", |
220 | | - } |
221 | | - ) |
222 | | - |
223 | | - train, _ = train_test_split(adult_data, test_size=0.2, random_state=17) |
224 | | - adult_data = train.reset_index(drop=True) |
225 | | - |
226 | | - # Remove the downloaded dataset |
227 | | - if os.path.isdir(outdirname): |
228 | | - entire_path = os.path.abspath(outdirname) |
229 | | - shutil.rmtree(entire_path) |
230 | | - |
231 | | - return adult_data |
| 44 | + return helpers.load_adult_income_dataset() |
232 | 45 |
|
233 | 46 |
|
234 | 47 | def create_adult_income_dataset(create_small_dataset=True): |
|
0 commit comments