|
198 | 198 | "source": [ |
199 | 199 | "## Tuning using a grid-search\n", |
200 | 200 | "\n", |
201 | | - "In the previous exercise we used one `for` loop for each hyperparameter to\n", |
202 | | - "find the best combination over a fixed grid of values. `GridSearchCV` is a\n", |
203 | | - "scikit-learn class that implements a very similar logic with less repetitive\n", |
204 | | - "code.\n", |
| 201 | + "In the previous exercise (M3.01) we used two nested `for` loops (one for each\n", |
| 202 | + "hyperparameter) to test different combinations over a fixed grid of\n", |
| 203 | + "hyperparameter values. In each iteration of the loop, we used\n", |
| 204 | + "`cross_val_score` to compute the mean score (as averaged across\n", |
| 205 | + "cross-validation splits), and compared those mean scores to select the best\n", |
| 206 | + "combination. `GridSearchCV` is a scikit-learn class that implements a very\n", |
| 207 | + "similar logic with less repetitive code. The suffix `CV` refers to the\n", |
| 208 | + "cross-validation it runs internally (instead of the `cross_val_score` we\n", |
| 209 | + "\"hard\" coded).\n", |
| 210 | + "\n", |
| 211 | + "The `GridSearchCV` estimator takes a `param_grid` parameter which defines all\n", |
| 212 | + "hyperparameters and their associated values. The grid-search is in charge of\n", |
| 213 | + "creating all possible combinations and testing them.\n", |
| 214 | + "\n", |
| 215 | + "The number of combinations is equal to the product of the number of values to\n", |
| 216 | + "explore for each parameter. Thus, adding new parameters with their associated\n", |
| 217 | + "values to be explored rapidly becomes computationally expensive. Because of\n", |
|     | 218 | + "that, here we only explore the combination of the learning-rate and the\n", |
|     | 219 | + "maximum number of leaf nodes, for a total of 4 x 3 = 12 combinations.\n", |
205 | 220 | "\n", |
206 | | - "Let's see how to use the `GridSearchCV` estimator for doing such search. Since\n", |
207 | | - "the grid-search is costly, we only explore the combination learning-rate and\n", |
208 | | - "the maximum number of nodes." |
209 | | - ] |
210 | | - }, |
211 | | - { |
212 | | - "cell_type": "code", |
213 | | - "execution_count": null, |
214 | | - "metadata": {}, |
215 | | - "outputs": [], |
216 | | - "source": [ |
217 | 221 | "%%time\n", |
218 | 222 | "from sklearn.model_selection import GridSearchCV\n", |
219 | 223 | "\n", |
220 | 224 | "param_grid = {\n", |
221 | | - " \"classifier__learning_rate\": (0.01, 0.1, 1, 10),\n", |
222 | | - " \"classifier__max_leaf_nodes\": (3, 10, 30),\n", |
223 | | - "}\n", |
| 225 | + " \"classifier__learning_rate\": (0.01, 0.1, 1, 10), # 4 possible values\n", |
| 226 | + " \"classifier__max_leaf_nodes\": (3, 10, 30), # 3 possible values\n", |
| 227 | + "} # 12 unique combinations\n", |
224 | 228 | "model_grid_search = GridSearchCV(model, param_grid=param_grid, n_jobs=2, cv=2)\n", |
225 | 229 | "model_grid_search.fit(data_train, target_train)" |
226 | 230 | ] |
|
229 | 233 | "cell_type": "markdown", |
230 | 234 | "metadata": {}, |
231 | 235 | "source": [ |
232 | | - "Finally, we check the accuracy of our model using the test set." |
| 236 | + "You can access the best combination of hyperparameters found by the grid\n", |
| 237 | + "search using the `best_params_` attribute." |
233 | 238 | ] |
234 | 239 | }, |
235 | 240 | { |
|
238 | 243 | "metadata": {}, |
239 | 244 | "outputs": [], |
240 | 245 | "source": [ |
241 | | - "accuracy = model_grid_search.score(data_test, target_test)\n", |
242 | | - "print(\n", |
243 | | - " f\"The test accuracy score of the grid-searched pipeline is: {accuracy:.2f}\"\n", |
244 | | - ")" |
245 | | - ] |
246 | | - }, |
247 | | - { |
248 | | - "cell_type": "markdown", |
249 | | - "metadata": {}, |
250 | | - "source": [ |
251 | | - "<div class=\"admonition warning alert alert-danger\">\n", |
252 | | - "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n", |
253 | | - "<p>Be aware that the evaluation should normally be performed through\n", |
254 | | - "cross-validation by providing <tt class=\"docutils literal\">model_grid_search</tt> as a model to the\n", |
255 | | - "<tt class=\"docutils literal\">cross_validate</tt> function.</p>\n", |
256 | | - "<p class=\"last\">Here, we used a single train-test split to evaluate <tt class=\"docutils literal\">model_grid_search</tt>. In\n", |
257 | | - "a future notebook will go into more detail about nested cross-validation, when\n", |
258 | | - "you use cross-validation both for hyperparameter tuning and model evaluation.</p>\n", |
259 | | - "</div>" |
| 246 | + "print(f\"The best set of parameters is: {model_grid_search.best_params_}\")" |
260 | 247 | ] |
261 | 248 | }, |
262 | 249 | { |
263 | 250 | "cell_type": "markdown", |
264 | 251 | "metadata": {}, |
265 | 252 | "source": [ |
266 | | - "The `GridSearchCV` estimator takes a `param_grid` parameter which defines all\n", |
267 | | - "hyperparameters and their associated values. The grid-search is in charge\n", |
268 | | - "of creating all possible combinations and test them.\n", |
269 | | - "\n", |
270 | | - "The number of combinations are equal to the product of the number of values to\n", |
271 | | - "explore for each parameter (e.g. in our example 4 x 3 combinations). Thus,\n", |
272 | | - "adding new parameters with their associated values to be explored become\n", |
273 | | - "rapidly computationally expensive.\n", |
274 | | - "\n", |
275 | | - "Once the grid-search is fitted, it can be used as any other predictor by\n", |
276 | | - "calling `predict` and `predict_proba`. Internally, it uses the model with the\n", |
| 253 | + "Once the grid-search is fitted, it can be used as any other estimator, i.e. it\n", |
| 254 | + "has `predict` and `score` methods. Internally, it uses the model with the\n", |
277 | 255 | "best parameters found during `fit`.\n", |
278 | 256 | "\n", |
279 | | - "Get predictions for the 5 first samples using the estimator with the best\n", |
280 | | - "parameters." |
|     | 257 | + "Let's get the predictions for the first 5 samples using the estimator\n", |
|     | 258 | + "with the best parameters:" |
281 | 259 | ] |
282 | 260 | }, |
283 | 261 | { |
|
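The "fitted grid-search behaves like a regular estimator" point can be illustrated with a small self-contained sketch. The dataset and estimator here are assumptions chosen only to keep the example fast (iris and `LogisticRegression`, not the notebook's pipeline); the delegation of `predict` and `score` to the refitted best model works the same way.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split

# Hypothetical small example: iris instead of the notebook's dataset.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1, 10]},
    cv=2,
)
search.fit(X_train, y_train)

# Both calls are delegated to the model refitted with the best parameters.
print(search.predict(X_test[:5]))  # predictions for the first 5 samples
print(f"Test accuracy: {search.score(X_test, y_test):.2f}")
```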
293 | 271 | "cell_type": "markdown", |
294 | 272 | "metadata": {}, |
295 | 273 | "source": [ |
296 | | - "You can know about these parameters by looking at the `best_params_`\n", |
297 | | - "attribute." |
| 274 | + "Finally, we check the accuracy of our model using the test set." |
298 | 275 | ] |
299 | 276 | }, |
300 | 277 | { |
|
303 | 280 | "metadata": {}, |
304 | 281 | "outputs": [], |
305 | 282 | "source": [ |
306 | | - "print(f\"The best set of parameters is: {model_grid_search.best_params_}\")" |
| 283 | + "accuracy = model_grid_search.score(data_test, target_test)\n", |
| 284 | + "print(\n", |
| 285 | + " f\"The test accuracy score of the grid-search pipeline is: {accuracy:.2f}\"\n", |
| 286 | + ")" |
307 | 287 | ] |
308 | 288 | }, |
309 | 289 | { |
310 | 290 | "cell_type": "markdown", |
311 | 291 | "metadata": {}, |
312 | 292 | "source": [ |
313 | | - "The accuracy and the best parameters of the grid-searched pipeline are similar\n", |
| 293 | + "The accuracy and the best parameters of the grid-search pipeline are similar\n", |
314 | 294 | "to the ones we found in the previous exercise, where we searched the best\n", |
315 | | - "parameters \"by hand\" through a double for loop.\n", |
| 295 | + "parameters \"by hand\" through a double `for` loop.\n", |
| 296 | + "\n", |
| 297 | + "## The need for a validation set\n", |
| 298 | + "\n", |
| 299 | + "In the previous section, the selection of the best hyperparameters was done\n", |
| 300 | + "using the train set, coming from the initial train-test split. Then, we\n", |
|     | 301 | + "evaluated the generalization performance of our tuned model on the left-out\n", |
| 302 | + "test set. This can be shown schematically as follows:\n", |
| 303 | + "\n", |
| 304 | + "\n", |
| 306 | + "\n", |
| 307 | + "<div class=\"admonition note alert alert-info\">\n", |
| 308 | + "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n", |
|     | 309 | + "<p>This figure shows the particular case of the <strong>K-fold</strong> cross-validation strategy\n", |
| 310 | + "using <tt class=\"docutils literal\">n_splits=5</tt> to further split the train set coming from a train-test\n", |
| 311 | + "split. For each cross-validation split, the procedure trains a model on all\n", |
|     | 312 | + "the red samples and evaluates the score of a given set of hyperparameters on the\n", |
| 313 | + "green samples. The best combination of hyperparameters <tt class=\"docutils literal\">best_params</tt> is selected\n", |
| 314 | + "based on those intermediate scores.</p>\n", |
| 315 | + "<p>Then a final model is refitted using <tt class=\"docutils literal\">best_params</tt> on the concatenation of the\n", |
| 316 | + "red and green samples and evaluated on the blue samples.</p>\n", |
|     | 317 | + "<p class=\"last\">The green samples are sometimes referred to as the <strong>validation set</strong> to\n", |
| 318 | + "differentiate them from the final test set in blue.</p>\n", |
| 319 | + "</div>\n", |
316 | 320 | "\n", |
317 | 321 | "In addition, we can inspect all results which are stored in the attribute\n", |
318 | 322 | "`cv_results_` of the grid-search. We filter some specific columns from these\n", |
|