|
24 | 24 | ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE |
25 | 25 | POSSIBILITY OF SUCH DAMAGE. |
26 | 26 | """ |
27 | | - |
28 | | -import sys |
29 | 27 | import os |
30 | 28 | import numpy as np |
31 | 29 | import pandas as pd |
32 | 30 |
|
| 31 | + |
| 32 | +# get absolute path of csv files from data folder |
| 33 | +def get_absPath(filename): |
| 34 | + """Returns the path of the notebooks folder""" |
| 35 | + path = os.path.abspath( |
| 36 | + os.path.join(os.path.dirname(__file__), os.path.pardir, |
| 37 | + os.path.pardir, "data", filename)) |
| 38 | + return path |
| 39 | + |
| 40 | + |
33 | 41 | # number of features |
34 | | -n_columns = 10 |
| 42 | +expected_columns = 10 |
35 | 43 |
|
36 | 44 | # distribution of features in the training set |
37 | 45 | historical_mean = np.array( |
|
65 | 73 | ] |
66 | 74 | ) |
67 | 75 |
|
68 | | -# maximal relative change in feature mean or standrd deviation that we can tolerate |
| 76 | +# maximal relative change in feature mean or standrd deviation |
| 77 | +# that we can tolerate |
69 | 78 | shift_tolerance = 3 |
70 | 79 |
|
71 | 80 |
|
72 | | -def check_schema(X): |
73 | | - n_actual_columns = X.shape[1] |
74 | | - if n_actual_columns != n_columns: |
75 | | - print( |
76 | | - "Error: found {} feature columns. The data should have {} feature columns.".format( |
77 | | - n_actual_columns, n_columns |
78 | | - ) |
79 | | - ) |
80 | | - return False |
81 | | - |
82 | | - return True |
83 | | - |
84 | | - |
85 | | -def check_missing_values(dataset): |
| 81 | +def test_check_schema(): |
| 82 | + datafile = get_absPath("diabetes.csv") |
| 83 | + # check that file exists |
| 84 | + assert(os.path.exists(datafile)) |
| 85 | + dataset = pd.read_csv(datafile) |
| 86 | + header = dataset[dataset.columns[:-1]] |
| 87 | + actual_columns = header.shape[1] |
| 88 | + # check header has expected number of columns |
| 89 | + assert(actual_columns == expected_columns) |
| 90 | + |
| 91 | + |
| 92 | +def test_check_bad_schema(): |
| 93 | + datafile = get_absPath("diabetes_bad_schema.csv") |
| 94 | + # check that file exists |
| 95 | + assert(os.path.exists(datafile)) |
| 96 | + dataset = pd.read_csv(datafile) |
| 97 | + header = dataset[dataset.columns[:-1]] |
| 98 | + actual_columns = header.shape[1] |
| 99 | + # check header has expected number of columns |
| 100 | + assert(actual_columns != expected_columns) |
| 101 | + |
| 102 | + |
| 103 | +def test_check_missing_values(): |
| 104 | + datafile = get_absPath("diabetes_missing_values.csv") |
| 105 | + # check that file exists |
| 106 | + assert(os.path.exists(datafile)) |
| 107 | + dataset = pd.read_csv(datafile) |
86 | 108 | n_nan = np.sum(np.isnan(dataset.values)) |
87 | | - if n_nan > 0: |
88 | | - print("Warning: the data has {} missing values".format(n_nan)) |
89 | | - return False |
90 | | - return True |
| 109 | + assert(n_nan > 0) |
91 | 110 |
|
92 | 111 |
|
93 | | -def check_distribution(dataset): |
| 112 | +def test_check_distribution(): |
| 113 | + datafile = get_absPath("diabetes_bad_dist.csv") |
| 114 | + # check that file exists |
| 115 | + assert(os.path.exists(datafile)) |
| 116 | + dataset = pd.read_csv(datafile) |
94 | 117 | mean = np.mean(dataset.values, axis=0) |
95 | 118 | std = np.mean(dataset.values, axis=0) |
96 | | - if ( |
97 | | - np.sum(abs(mean - historical_mean) > shift_tolerance * abs(historical_mean)) > 0 |
98 | | - or np.sum(abs(std - historical_std) > shift_tolerance * abs(historical_std)) > 0 |
99 | | - ): |
100 | | - print("Warning: new data has different distribution than the training data") |
101 | | - return False |
102 | | - return True |
103 | | - |
104 | | - |
105 | | -def main(): |
106 | | - filename = sys.argv[1] |
107 | | - if not os.path.exists(filename): |
108 | | - print("Error: The file {} does not exist".format(filename)) |
109 | | - return |
110 | | - |
111 | | - dataset = pd.read_csv(filename) |
112 | | - if check_schema(dataset[dataset.columns[:-1]]): |
113 | | - print("Data schema test succeeded") |
114 | | - if check_missing_values(dataset) and check_distribution(dataset): |
115 | | - print("Missing values test passed") |
116 | | - print("Data distribution test passed") |
117 | | - else: |
118 | | - print( |
119 | | - "There might be some issues with the data. Please check warning messages." |
120 | | - ) |
121 | | - |
122 | | - |
123 | | -if __name__ == "__main__": |
124 | | - main() |
| 119 | + assert(np.sum(abs(mean - historical_mean) > shift_tolerance * |
| 120 | + abs(historical_mean)) or |
| 121 | + np.sum(abs(std - historical_std) > shift_tolerance * |
| 122 | + abs(historical_std)) > 0) |
0 commit comments