Skip to content

Commit af9b099

Browse files
authored
Update dice-ml to 0.10.0 (#2145)
* Update dice-ml to 0.10.0 Signed-off-by: Gaurav Gupta <[email protected]> * Fix lint Signed-off-by: Gaurav Gupta <[email protected]> * Fix failing test Signed-off-by: Gaurav Gupta <[email protected]> --------- Signed-off-by: Gaurav Gupta <[email protected]>
1 parent 5bedefb commit af9b099

File tree

3 files changed

+4
-193
lines changed

3 files changed

+4
-193
lines changed

responsibleai/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
dice-ml>=0.9,<0.10
1+
dice-ml>=0.10,<0.11
22
econml>=0.14.1
33
statsmodels<0.14.0
44
jsonschema

responsibleai/tests/common_utils.py

Lines changed: 2 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
# Licensed under the MIT License.
33

44
import json
5-
import os
6-
import shutil
7-
import zipfile
8-
from urllib.request import urlretrieve
95

106
import numpy as np
117
import pandas as pd
128
import pytest
9+
from dice_ml.utils import helpers
1310
# Defines common utilities for responsibleai tests
1411
from sklearn.model_selection import train_test_split
1512

@@ -44,191 +41,7 @@ def __init__(self):
4441
pass
4542

4643
def fetch(self):
47-
"""
48-
Loads adult income dataset from
49-
https://archive.ics.uci.edu/ml/datasets/Adult and prepares the
50-
data for data analysis based on https://rpubs.com/H_Zhu/235617
51-
52-
:return adult_data: returns preprocessed adult income dataset.
53-
"""
54-
# TODO: Revert to using load_adult_income_dataset once dice-ml has a
55-
# new release with the fix.
56-
# Download the adult dataset from
57-
# https://archive.ics.uci.edu/static/public/2/adult.zip as a zip folder
58-
outdirname = "adult"
59-
zipfilename = outdirname + ".zip"
60-
urlretrieve(
61-
"https://archive.ics.uci.edu/static/public/2/adult.zip",
62-
zipfilename,
63-
)
64-
with zipfile.ZipFile(zipfilename, "r") as unzip:
65-
unzip.extractall(outdirname)
66-
67-
raw_data = np.genfromtxt(
68-
outdirname + "/adult.data",
69-
delimiter=", ",
70-
dtype=str,
71-
invalid_raise=False,
72-
)
73-
74-
# column names from "https://archive.ics.uci.edu/ml/datasets/Adult"
75-
column_names = [
76-
"age",
77-
"workclass",
78-
"fnlwgt",
79-
"education",
80-
"educational-num",
81-
"marital-status",
82-
"occupation",
83-
"relationship",
84-
"race",
85-
"gender",
86-
"capital-gain",
87-
"capital-loss",
88-
"hours-per-week",
89-
"native-country",
90-
"income",
91-
]
92-
93-
adult_data = pd.DataFrame(raw_data, columns=column_names)
94-
95-
# For more details on how the below transformations are made,
96-
# please refer to https://rpubs.com/H_Zhu/235617
97-
adult_data = adult_data.astype(
98-
{
99-
"age": np.int64,
100-
"educational-num": np.int64,
101-
"hours-per-week": np.int64,
102-
}
103-
)
104-
105-
adult_data = adult_data.replace(
106-
{
107-
"workclass": {
108-
"Without-pay": "Other/Unknown",
109-
"Never-worked": "Other/Unknown",
110-
}
111-
}
112-
)
113-
adult_data = adult_data.replace(
114-
{
115-
"workclass": {
116-
"Federal-gov": "Government",
117-
"State-gov": "Government",
118-
"Local-gov": "Government",
119-
}
120-
}
121-
)
122-
adult_data = adult_data.replace(
123-
{
124-
"workclass": {
125-
"Self-emp-not-inc": "Self-Employed",
126-
"Self-emp-inc": "Self-Employed",
127-
}
128-
}
129-
)
130-
adult_data = adult_data.replace(
131-
{
132-
"workclass": {
133-
"Never-worked": "Self-Employed",
134-
"Without-pay": "Self-Employed",
135-
}
136-
}
137-
)
138-
adult_data = adult_data.replace({"workclass": {"?": "Other/Unknown"}})
139-
140-
adult_data = adult_data.replace(
141-
{
142-
"occupation": {
143-
"Adm-clerical": "White-Collar",
144-
"Craft-repair": "Blue-Collar",
145-
"Exec-managerial": "White-Collar",
146-
"Farming-fishing": "Blue-Collar",
147-
"Handlers-cleaners": "Blue-Collar",
148-
"Machine-op-inspct": "Blue-Collar",
149-
"Other-service": "Service",
150-
"Priv-house-serv": "Service",
151-
"Prof-specialty": "Professional",
152-
"Protective-serv": "Service",
153-
"Tech-support": "Service",
154-
"Transport-moving": "Blue-Collar",
155-
"Unknown": "Other/Unknown",
156-
"Armed-Forces": "Other/Unknown",
157-
"?": "Other/Unknown",
158-
}
159-
}
160-
)
161-
162-
adult_data = adult_data.replace(
163-
{
164-
"marital-status": {
165-
"Married-civ-spouse": "Married",
166-
"Married-AF-spouse": "Married",
167-
"Married-spouse-absent": "Married",
168-
"Never-married": "Single",
169-
}
170-
}
171-
)
172-
173-
adult_data = adult_data.replace(
174-
{
175-
"race": {
176-
"Black": "Other",
177-
"Asian-Pac-Islander": "Other",
178-
"Amer-Indian-Eskimo": "Other",
179-
}
180-
}
181-
)
182-
183-
adult_data = adult_data[
184-
[
185-
"age",
186-
"workclass",
187-
"education",
188-
"marital-status",
189-
"occupation",
190-
"race",
191-
"gender",
192-
"hours-per-week",
193-
"income",
194-
]
195-
]
196-
197-
adult_data = adult_data.replace({"income": {"<=50K": 0, ">50K": 1}})
198-
199-
adult_data = adult_data.replace(
200-
{
201-
"education": {
202-
"Assoc-voc": "Assoc",
203-
"Assoc-acdm": "Assoc",
204-
"11th": "School",
205-
"10th": "School",
206-
"7th-8th": "School",
207-
"9th": "School",
208-
"12th": "School",
209-
"5th-6th": "School",
210-
"1st-4th": "School",
211-
"Preschool": "School",
212-
}
213-
}
214-
)
215-
216-
adult_data = adult_data.rename(
217-
columns={
218-
"marital-status": "marital_status",
219-
"hours-per-week": "hours_per_week",
220-
}
221-
)
222-
223-
train, _ = train_test_split(adult_data, test_size=0.2, random_state=17)
224-
adult_data = train.reset_index(drop=True)
225-
226-
# Remove the downloaded dataset
227-
if os.path.isdir(outdirname):
228-
entire_path = os.path.abspath(outdirname)
229-
shutil.rmtree(entire_path)
230-
231-
return adult_data
44+
return helpers.load_adult_income_dataset()
23245

23346

23447
def create_adult_income_dataset(create_small_dataset=True):

responsibleai/tests/counterfactual_manager_validator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# Licensed under the MIT License.
33

44
import pytest
5-
from dice_ml.utils.exception import \
6-
UserConfigValidationException as DiceException
75

86
from raiutils.exceptions import UserConfigValidationException
97
from responsibleai._internal.constants import (CounterfactualManagerKeys,
@@ -112,7 +110,7 @@ def validate_counterfactual(cf_analyzer,
112110
counterfactual_props=cf_analyzer.counterfactual.list(),
113111
expected_counterfactuals=2)
114112
else:
115-
with pytest.raises(DiceException):
113+
with pytest.raises(UserConfigValidationException):
116114
cf_analyzer.counterfactual.add(
117115
total_CFs=-2,
118116
method='random',

0 commit comments

Comments
 (0)