Skip to content

Commit 9f2c0f0

Browse files
author
Nabil Fayak
committed
uncommented target distr test
1 parent 2e4bc52 commit 9f2c0f0

File tree

1 file changed

+133
-133
lines changed

1 file changed

+133
-133
lines changed
Lines changed: 133 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,133 +1,133 @@
1-
# import numpy as np
2-
# import pandas as pd
3-
# import pytest
4-
# from scipy.stats import jarque_bera, lognorm, norm, shapiro
5-
6-
# from checkmates.data_checks import (
7-
# DataCheckActionCode,
8-
# DataCheckActionOption,
9-
# DataCheckError,
10-
# DataCheckMessageCode,
11-
# DataCheckWarning,
12-
# TargetDistributionDataCheck,
13-
# )
14-
# from checkmates.utils import infer_feature_types
15-
16-
# target_dist_check_name = TargetDistributionDataCheck.name
17-
18-
19-
# def test_target_distribution_data_check_no_y(X_y_regression):
20-
# X, y = X_y_regression
21-
# y = None
22-
23-
# target_dist_check = TargetDistributionDataCheck()
24-
25-
# assert target_dist_check.validate(X, y) == [
26-
# DataCheckError(
27-
# message="Target is None",
28-
# data_check_name=target_dist_check_name,
29-
# message_code=DataCheckMessageCode.TARGET_IS_NONE,
30-
# details={},
31-
# ).to_dict(),
32-
# ]
33-
34-
35-
# @pytest.mark.parametrize("target_type", ["boolean", "categorical", "integer", "double"])
36-
# def test_target_distribution_data_check_unsupported_target_type(target_type):
37-
# X = pd.DataFrame(range(5))
38-
39-
# if target_type == "boolean":
40-
# y = pd.Series([True, False] * 5)
41-
# elif target_type == "categorical":
42-
# y = pd.Series(["One", "Two", "Three", "Four", "Five"] * 2)
43-
# elif target_type == "integer":
44-
# y = [-1, -3, -5, 4, -2, 4, -4, 2, 1, 1]
45-
# else:
46-
# y = [9.2, 7.66, 4.93, 3.29, 4.06, -1.28, 4.95, 6.77, 9.07, 7.67]
47-
48-
# y = infer_feature_types(y)
49-
50-
# target_dist_check = TargetDistributionDataCheck()
51-
52-
# if target_type in ["integer", "double"]:
53-
# assert target_dist_check.validate(X, y) == []
54-
# else:
55-
# assert target_dist_check.validate(X, y) == [
56-
# DataCheckError(
57-
# message=f"Target is unsupported {y.ww.logical_type.type_string} type. Valid Woodwork logical types include: integer, double, age, age_fractional",
58-
# data_check_name=target_dist_check_name,
59-
# message_code=DataCheckMessageCode.TARGET_UNSUPPORTED_TYPE,
60-
# details={"unsupported_type": y.ww.logical_type.type_string},
61-
# ).to_dict(),
62-
# ]
63-
64-
65-
# @pytest.mark.parametrize("data_type", ["positive", "mixed", "negative"])
66-
# @pytest.mark.parametrize("distribution", ["normal", "lognormal", "very_lognormal"])
67-
# @pytest.mark.parametrize(
68-
# "size,name,statistic",
69-
# [(10000, "jarque_bera", jarque_bera), (5000, "shapiro", shapiro)],
70-
# )
71-
# def test_target_distribution_data_check_warning_action(
72-
# size,
73-
# name,
74-
# statistic,
75-
# distribution,
76-
# data_type,
77-
# X_y_regression,
78-
# ):
79-
# X, y = X_y_regression
80-
# # set this to avoid flaky tests. This is primarily because when we have smaller samples,
81-
# # once we remove values outside 3 st.devs, the distribution can begin to look more normal
82-
# random_state = 2
83-
# target_dist_check = TargetDistributionDataCheck()
84-
85-
# if distribution == "normal":
86-
# y = norm.rvs(loc=3, size=size, random_state=random_state)
87-
# elif distribution == "lognormal":
88-
# y = lognorm.rvs(0.4, size=size, random_state=random_state)
89-
# else:
90-
# # Will have a p-value of 0 thereby rejecting the null hypothesis even after log transforming
91-
# # This is essentially just checking the = of the statistic's "log.pvalue >= og.pvalue"
92-
# y = lognorm.rvs(s=1, loc=1, scale=1, size=size, random_state=random_state)
93-
94-
# y = np.round(y, 6)
95-
96-
# if data_type == "negative":
97-
# y = -np.abs(y)
98-
# elif data_type == "mixed":
99-
# y = y - 1.2
100-
101-
# if distribution == "normal":
102-
# assert target_dist_check.validate(X, y) == []
103-
# else:
104-
# target_dist_ = target_dist_check.validate(X, y)
105-
106-
# if any(y <= 0):
107-
# y = y + abs(y.min()) + 1
108-
# y = y[y < (y.mean() + 3 * round(y.std(), 3))]
109-
# test_og = statistic(y)
110-
111-
# details = {
112-
# "normalization_method": name,
113-
# "statistic": round(test_og.statistic, 1),
114-
# "p-value": round(test_og.pvalue, 3),
115-
# }
116-
# assert target_dist_ == [
117-
# DataCheckWarning(
118-
# message="Target may have a lognormal distribution.",
119-
# data_check_name=target_dist_check_name,
120-
# message_code=DataCheckMessageCode.TARGET_LOGNORMAL_DISTRIBUTION,
121-
# details=details,
122-
# action_options=[
123-
# DataCheckActionOption(
124-
# DataCheckActionCode.TRANSFORM_TARGET,
125-
# data_check_name=target_dist_check_name,
126-
# metadata={
127-
# "is_target": True,
128-
# "transformation_strategy": "lognormal",
129-
# },
130-
# ),
131-
# ],
132-
# ).to_dict(),
133-
# ]
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
from scipy.stats import jarque_bera, lognorm, norm, shapiro
5+
6+
from checkmates.data_checks import (
7+
DataCheckActionCode,
8+
DataCheckActionOption,
9+
DataCheckError,
10+
DataCheckMessageCode,
11+
DataCheckWarning,
12+
TargetDistributionDataCheck,
13+
)
14+
from checkmates.utils import infer_feature_types
15+
16+
# Name attribute of TargetDistributionDataCheck; the tests in this module pass it
# as data_check_name when building the expected errors, warnings and action options.
target_dist_check_name = TargetDistributionDataCheck.name
17+
18+
19+
def test_target_distribution_data_check_no_y(X_y_regression):
    """Check that validating with no target yields a single TARGET_IS_NONE error.

    Args:
        X_y_regression: Provides an ``(X, y)`` pair; only ``X`` is used here,
            the target is deliberately replaced by ``None``.
    """
    # The provided target is irrelevant: this test is about a missing target.
    X, _ = X_y_regression

    target_dist_check = TargetDistributionDataCheck()

    assert target_dist_check.validate(X, None) == [
        DataCheckError(
            message="Target is None",
            data_check_name=target_dist_check_name,
            message_code=DataCheckMessageCode.TARGET_IS_NONE,
            details={},
        ).to_dict(),
    ]
33+
34+
35+
@pytest.mark.parametrize("target_type", ["boolean", "categorical", "integer", "double"])
def test_target_distribution_data_check_unsupported_target_type(target_type):
    """Check how the data check reacts to each kind of target.

    Integer and double targets are expected to produce no messages, while
    boolean and categorical targets are expected to produce a single
    TARGET_UNSUPPORTED_TYPE error naming the offending logical type.

    Args:
        target_type: Which kind of sample target to build for this run.
    """
    features = pd.DataFrame(range(5))

    # One sample target per parametrized type; anything not listed is "double".
    sample_targets = {
        "boolean": pd.Series([True, False] * 5),
        "categorical": pd.Series(["One", "Two", "Three", "Four", "Five"] * 2),
        "integer": [-1, -3, -5, 4, -2, 4, -4, 2, 1, 1],
    }
    double_target = [9.2, 7.66, 4.93, 3.29, 4.06, -1.28, 4.95, 6.77, 9.07, 7.67]
    target = infer_feature_types(sample_targets.get(target_type, double_target))

    data_check = TargetDistributionDataCheck()
    messages = data_check.validate(features, target)

    if target_type in ("integer", "double"):
        assert messages == []
        return

    logical_type_name = target.ww.logical_type.type_string
    expected_error = DataCheckError(
        message=f"Target is unsupported {logical_type_name} type. Valid Woodwork logical types include: integer, double, age, age_fractional",
        data_check_name=target_dist_check_name,
        message_code=DataCheckMessageCode.TARGET_UNSUPPORTED_TYPE,
        details={"unsupported_type": logical_type_name},
    ).to_dict()
    assert messages == [expected_error]
63+
64+
65+
@pytest.mark.parametrize("data_type", ["positive", "mixed", "negative"])
@pytest.mark.parametrize("distribution", ["normal", "lognormal", "very_lognormal"])
@pytest.mark.parametrize(
    "size,name,statistic",
    [(10000, "jarque_bera", jarque_bera), (5000, "shapiro", shapiro)],
)
def test_target_distribution_data_check_warning_action(
    size,
    name,
    statistic,
    distribution,
    data_type,
    X_y_regression,
):
    """Check the lognormal warning and its transform-target action option.

    A normally distributed target is expected to produce no messages. A
    lognormal target is expected to produce one TARGET_LOGNORMAL_DISTRIBUTION
    warning whose details carry the normality-test name, statistic and
    p-value, plus a TRANSFORM_TARGET action option.

    Args:
        size: Number of target samples to draw.
        name: Normalization method name expected in the warning details.
        statistic: Normality test (from scipy.stats) used to recompute the
            expected statistic and p-value.
        distribution: "normal", "lognormal" or "very_lognormal"; selects how
            the target is sampled.
        data_type: "positive", "mixed" or "negative"; selects whether the
            sampled target is left as is, shifted down, or made non-positive.
        X_y_regression: Provides an ``(X, y)`` pair; only ``X`` is used, the
            target is always generated inside the test.
    """
    # The provided target is discarded: every branch below samples its own.
    X, _ = X_y_regression
    # Fixed seed to avoid flaky tests. With smaller samples, once values beyond
    # 3 standard deviations are removed the distribution can start to look
    # more normal.
    random_state = 2
    target_dist_check = TargetDistributionDataCheck()

    if distribution == "normal":
        y = norm.rvs(loc=3, size=size, random_state=random_state)
    elif distribution == "lognormal":
        y = lognorm.rvs(0.4, size=size, random_state=random_state)
    else:
        # Will have a p-value of 0, thereby rejecting the null hypothesis even
        # after log transforming. This essentially exercises the equality case
        # of the "log.pvalue >= og.pvalue" comparison.
        y = lognorm.rvs(s=1, loc=1, scale=1, size=size, random_state=random_state)

    y = np.round(y, 6)

    if data_type == "negative":
        y = -np.abs(y)
    elif data_type == "mixed":
        y = y - 1.2

    if distribution == "normal":
        assert target_dist_check.validate(X, y) == []
    else:
        target_dist_ = target_dist_check.validate(X, y)

        # Recompute the expected test result on the same target: shift it to be
        # strictly positive when needed, keep only values below
        # mean + 3 * std, then run the normality test.
        # NOTE(review): presumably mirrors the preprocessing inside the data
        # check -- confirm against its implementation.
        if (y <= 0).any():
            y = y + abs(y.min()) + 1
        y = y[y < (y.mean() + 3 * round(y.std(), 3))]
        test_og = statistic(y)

        details = {
            "normalization_method": name,
            "statistic": round(test_og.statistic, 1),
            "p-value": round(test_og.pvalue, 3),
        }
        assert target_dist_ == [
            DataCheckWarning(
                message="Target may have a lognormal distribution.",
                data_check_name=target_dist_check_name,
                message_code=DataCheckMessageCode.TARGET_LOGNORMAL_DISTRIBUTION,
                details=details,
                action_options=[
                    DataCheckActionOption(
                        DataCheckActionCode.TRANSFORM_TARGET,
                        data_check_name=target_dist_check_name,
                        metadata={
                            "is_target": True,
                            "transformation_strategy": "lognormal",
                        },
                    ),
                ],
            ).to_dict(),
        ]

0 commit comments

Comments
 (0)