Skip to content

Commit e86d9f6

Browse files
authored
Merge pull request #5 from Microsoft/rijai-unittest
Rewrote existing data test to follow pytest format
2 parents 5a49ac0 + addf7a7 commit e86d9f6

File tree

3 files changed

+53
-54
lines changed

3 files changed

+53
-54
lines changed

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ steps:
2828
2929
displayName: 'replace subscription value'
3030

31-
- script: 'python code/testing/data_test.py data/diabetes.csv && python code/testing/data_test.py data/diabetes_bad_dist.csv && python code/testing/data_test.py data/diabetes_bad_schema.csv && python code/testing/data_test.py data/diabetes_missing_values.csv'
31+
- script: 'pytest tests/unit/data_test.py'
3232
displayName: 'Data Quality Check'
3333

3434
- script: 'python aml_service/00-WorkSpace.py'

environment_setup/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
scipy==1.0.0
22
scikit-learn==0.19.1
33
numpy==1.14.5
4-
pandas==0.23.1
4+
pandas==0.23.1
5+
pytest==4.3.0
Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,22 @@
2424
ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE
2525
POSSIBILITY OF SUCH DAMAGE.
2626
"""
27-
28-
import sys
2927
import os
3028
import numpy as np
3129
import pandas as pd
3230

31+
32+
# get absolute path of csv files from data folder
33+
def get_absPath(filename):
34+
"""Returns the absolute path of the given file inside the data folder"""
35+
path = os.path.abspath(
36+
os.path.join(os.path.dirname(__file__), os.path.pardir,
37+
os.path.pardir, "data", filename))
38+
return path
39+
40+
3341
# number of features
34-
n_columns = 10
42+
expected_columns = 10
3543

3644
# distribution of features in the training set
3745
historical_mean = np.array(
@@ -65,60 +73,50 @@
6573
]
6674
)
6775

68-
# maximal relative change in feature mean or standrd deviation that we can tolerate
76+
# maximal relative change in feature mean or standard deviation
77+
# that we can tolerate
6978
shift_tolerance = 3
7079

7180

72-
def check_schema(X):
73-
n_actual_columns = X.shape[1]
74-
if n_actual_columns != n_columns:
75-
print(
76-
"Error: found {} feature columns. The data should have {} feature columns.".format(
77-
n_actual_columns, n_columns
78-
)
79-
)
80-
return False
81-
82-
return True
83-
84-
85-
def check_missing_values(dataset):
81+
def test_check_schema():
82+
datafile = get_absPath("diabetes.csv")
83+
# check that file exists
84+
assert(os.path.exists(datafile))
85+
dataset = pd.read_csv(datafile)
86+
header = dataset[dataset.columns[:-1]]
87+
actual_columns = header.shape[1]
88+
# check header has expected number of columns
89+
assert(actual_columns == expected_columns)
90+
91+
92+
def test_check_bad_schema():
93+
datafile = get_absPath("diabetes_bad_schema.csv")
94+
# check that file exists
95+
assert(os.path.exists(datafile))
96+
dataset = pd.read_csv(datafile)
97+
header = dataset[dataset.columns[:-1]]
98+
actual_columns = header.shape[1]
99+
# check header does NOT have the expected number of columns
100+
assert(actual_columns != expected_columns)
101+
102+
103+
def test_check_missing_values():
104+
datafile = get_absPath("diabetes_missing_values.csv")
105+
# check that file exists
106+
assert(os.path.exists(datafile))
107+
dataset = pd.read_csv(datafile)
86108
n_nan = np.sum(np.isnan(dataset.values))
87-
if n_nan > 0:
88-
print("Warning: the data has {} missing values".format(n_nan))
89-
return False
90-
return True
109+
assert(n_nan > 0)
91110

92111

93-
def check_distribution(dataset):
112+
def test_check_distribution():
113+
datafile = get_absPath("diabetes_bad_dist.csv")
114+
# check that file exists
115+
assert(os.path.exists(datafile))
116+
dataset = pd.read_csv(datafile)
94117
mean = np.mean(dataset.values, axis=0)
95118
std = np.mean(dataset.values, axis=0)
96-
if (
97-
np.sum(abs(mean - historical_mean) > shift_tolerance * abs(historical_mean)) > 0
98-
or np.sum(abs(std - historical_std) > shift_tolerance * abs(historical_std)) > 0
99-
):
100-
print("Warning: new data has different distribution than the training data")
101-
return False
102-
return True
103-
104-
105-
def main():
106-
filename = sys.argv[1]
107-
if not os.path.exists(filename):
108-
print("Error: The file {} does not exist".format(filename))
109-
return
110-
111-
dataset = pd.read_csv(filename)
112-
if check_schema(dataset[dataset.columns[:-1]]):
113-
print("Data schema test succeeded")
114-
if check_missing_values(dataset) and check_distribution(dataset):
115-
print("Missing values test passed")
116-
print("Data distribution test passed")
117-
else:
118-
print(
119-
"There might be some issues with the data. Please check warning messages."
120-
)
121-
122-
123-
if __name__ == "__main__":
124-
main()
119+
assert(np.sum(abs(mean - historical_mean) > shift_tolerance *
120+
abs(historical_mean)) or
121+
np.sum(abs(std - historical_std) > shift_tolerance *
122+
abs(historical_std)) > 0)

0 commit comments

Comments
 (0)