12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import pandas
16
- import pytest
17
-
18
15
from bigframes .ml import globals
16
+ from tests .system import utils
19
17
20
18
21
- # TODO(garrettwu): Re-enable or not check exact numbers.
22
- @pytest .mark .skip (reason = "bqml regression" )
23
19
def test_bqml_e2e (session , dataset_id , penguins_df_default_index , new_penguins_df ):
24
20
df = penguins_df_default_index .dropna ()
25
21
X_train = df [
@@ -38,41 +34,33 @@ def test_bqml_e2e(session, dataset_id, penguins_df_default_index, new_penguins_d
38
34
X_train , y_train , options = {"model_type" : "linear_reg" }
39
35
)
40
36
37
+ eval_metrics = [
38
+ "mean_absolute_error" ,
39
+ "mean_squared_error" ,
40
+ "mean_squared_log_error" ,
41
+ "median_absolute_error" ,
42
+ "r2_score" ,
43
+ "explained_variance" ,
44
+ ]
41
45
# no data - report evaluation from the automatic data split
42
46
evaluate_result = model .evaluate ().to_pandas ()
43
- evaluate_expected = pandas .DataFrame (
44
- {
45
- "mean_absolute_error" : [225.817334 ],
46
- "mean_squared_error" : [80540.705944 ],
47
- "mean_squared_log_error" : [0.004972 ],
48
- "median_absolute_error" : [173.080816 ],
49
- "r2_score" : [0.87529 ],
50
- "explained_variance" : [0.87529 ],
51
- },
52
- dtype = "Float64" ,
53
- )
54
- evaluate_expected = evaluate_expected .reindex (
55
- index = evaluate_expected .index .astype ("Int64" )
56
- )
57
- pandas .testing .assert_frame_equal (
58
- evaluate_result , evaluate_expected , check_exact = False , rtol = 0.1
47
+ utils .check_pandas_df_schema_and_index (
48
+ evaluate_result , columns = eval_metrics , index = 1
59
49
)
60
50
61
51
# evaluate on all training data
62
52
evaluate_result = model .evaluate (df ).to_pandas ()
63
- pandas . testing . assert_frame_equal (
64
- evaluate_result , evaluate_expected , check_exact = False , rtol = 0. 1
53
+ utils . check_pandas_df_schema_and_index (
54
+ evaluate_result , columns = eval_metrics , index = 1
65
55
)
66
56
67
57
# predict new labels
68
58
predictions = model .predict (new_penguins_df ).to_pandas ()
69
- expected = pandas .DataFrame (
70
- {"predicted_body_mass_g" : [4030.1 , 3280.8 , 3177.9 ]},
71
- dtype = "Float64" ,
72
- index = pandas .Index ([1633 , 1672 , 1690 ], name = "tag_number" , dtype = "Int64" ),
73
- )
74
- pandas .testing .assert_frame_equal (
75
- predictions [["predicted_body_mass_g" ]], expected , check_exact = False , rtol = 0.1
59
+ utils .check_pandas_df_schema_and_index (
60
+ predictions ,
61
+ columns = ["predicted_body_mass_g" ],
62
+ index = [1633 , 1672 , 1690 ],
63
+ col_exact = False ,
76
64
)
77
65
78
66
new_name = f"{ dataset_id } .my_model"
@@ -108,42 +96,34 @@ def test_bqml_manual_preprocessing_e2e(
108
96
X_train , y_train , transforms = transforms , options = options
109
97
)
110
98
99
+ eval_metrics = [
100
+ "mean_absolute_error" ,
101
+ "mean_squared_error" ,
102
+ "mean_squared_log_error" ,
103
+ "median_absolute_error" ,
104
+ "r2_score" ,
105
+ "explained_variance" ,
106
+ ]
107
+
111
108
# no data - report evaluation from the automatic data split
112
109
evaluate_result = model .evaluate ().to_pandas ()
113
- evaluate_expected = pandas .DataFrame (
114
- {
115
- "mean_absolute_error" : [309.477334 ],
116
- "mean_squared_error" : [152184.227218 ],
117
- "mean_squared_log_error" : [0.009524 ],
118
- "median_absolute_error" : [257.727777 ],
119
- "r2_score" : [0.764356 ],
120
- "explained_variance" : [0.764356 ],
121
- },
122
- dtype = "Float64" ,
123
- )
124
- evaluate_expected = evaluate_expected .reindex (
125
- index = evaluate_expected .index .astype ("Int64" )
126
- )
127
-
128
- pandas .testing .assert_frame_equal (
129
- evaluate_result , evaluate_expected , check_exact = False , rtol = 0.1
110
+ utils .check_pandas_df_schema_and_index (
111
+ evaluate_result , columns = eval_metrics , index = 1
130
112
)
131
113
132
114
# evaluate on all training data
133
115
evaluate_result = model .evaluate (df ).to_pandas ()
134
- pandas . testing . assert_frame_equal (
135
- evaluate_result , evaluate_expected , check_exact = False , rtol = 0. 1
116
+ utils . check_pandas_df_schema_and_index (
117
+ evaluate_result , columns = eval_metrics , index = 1
136
118
)
137
119
138
120
# predict new labels
139
121
predictions = model .predict (new_penguins_df ).to_pandas ()
140
- expected = pandas .DataFrame (
141
- {"predicted_body_mass_g" : [3968.8 , 3176.3 , 3545.2 ]},
142
- dtype = "Float64" ,
143
- index = pandas .Index ([1633 , 1672 , 1690 ], name = "tag_number" , dtype = "Int64" ),
144
- )
145
- pandas .testing .assert_frame_equal (
146
- predictions [["predicted_body_mass_g" ]], expected , check_exact = False , rtol = 0.1
122
+ utils .check_pandas_df_schema_and_index (
123
+ predictions ,
124
+ columns = ["predicted_body_mass_g" ],
125
+ index = [1633 , 1672 , 1690 ],
126
+ col_exact = False ,
147
127
)
148
128
149
129
new_name = f"{ dataset_id } .my_model"
@@ -168,24 +148,9 @@ def test_bqml_standalone_transform(penguins_df_default_index, new_penguins_df):
168
148
)
169
149
170
150
transformed = model .transform (new_penguins_df ).to_pandas ()
171
- expected = pandas .DataFrame (
172
- {
173
- "scaled_culmen_length_mm" : [- 0.8099 , - 0.9931 , - 1.103 ],
174
- "onehotencoded_species" : [
175
- [{"index" : 1 , "value" : 1.0 }],
176
- [{"index" : 1 , "value" : 1.0 }],
177
- [{"index" : 2 , "value" : 1.0 }],
178
- ],
179
- },
180
- index = pandas .Index ([1633 , 1672 , 1690 ], name = "tag_number" , dtype = "Int64" ),
181
- )
182
- expected ["scaled_culmen_length_mm" ] = expected ["scaled_culmen_length_mm" ].astype (
183
- "Float64"
184
- )
185
- pandas .testing .assert_frame_equal (
186
- transformed [["scaled_culmen_length_mm" , "onehotencoded_species" ]],
187
- expected ,
188
- check_exact = False ,
189
- rtol = 0.1 ,
190
- check_dtype = False ,
151
+ utils .check_pandas_df_schema_and_index (
152
+ transformed ,
153
+ columns = ["scaled_culmen_length_mm" , "onehotencoded_species" ],
154
+ index = [1633 , 1672 , 1690 ],
155
+ col_exact = False ,
191
156
)
0 commit comments