|
54 | 54 | },
|
55 | 55 | "outputs": [],
|
56 | 56 | "source": [
|
57 |
| - "import graphing # custom graphing code. See our GitHub repo for details\n", |
58 | 57 | "import numpy as np\n",
|
| 58 | + "import matplotlib.pyplot as plt\n", |
59 | 59 | "\n",
|
60 | 60 | "# Crime category\n",
|
61 |
| - "graphing.multiple_histogram(dataset, label_x='Category', label_group=\"Resolution\", histfunc='sum', show=True)\n", |
| 61 | + "dataset.groupby(['Category', 'Resolution']).size().unstack().plot(kind='bar', stacked=False)\n", |
| 62 | + "plt.xlabel('Category')\n", |
| 63 | + "plt.ylabel('Count')\n", |
| 64 | + "plt.title('Crimes by Category and Resolution')\n", |
| 65 | + "plt.tight_layout()\n", |
| 66 | + "plt.show()\n", |
62 | 67 | "\n",
|
63 | 68 | "# District\n",
|
64 |
| - "graphing.multiple_histogram(dataset, label_group=\"Resolution\", label_x=\"PdDistrict\", show=True)\n", |
| 69 | + "dataset.groupby(['PdDistrict', 'Resolution']).size().unstack().plot(kind='bar', stacked=False)\n", |
| 70 | + "plt.xlabel('Police District')\n", |
| 71 | + "plt.ylabel('Count')\n", |
| 72 | + "plt.title('Crimes by District and Resolution')\n", |
| 73 | + "plt.tight_layout()\n", |
| 74 | + "plt.show()\n", |
65 | 75 | "\n",
|
66 | 76 | "# Map of crimes\n",
|
67 |
| - "graphing.scatter_2D(dataset, label_x=\"X\", label_y=\"Y\", label_colour=\"Resolution\", title=\"GPS Coordinates\", size_multiplier=0.8, show=True)\n", |
| 77 | + "import seaborn as sns\n", |
68 | 78 | "\n",
|
69 |
| - "# Day of the week\n", |
70 |
| - "graphing.multiple_histogram(dataset, label_group=\"Resolution\", label_x=\"DayOfWeek\", show=True)\n", |
| 79 | + "plt.figure(figsize=(10, 6))\n", |
| 80 | + "sns.scatterplot(data=dataset, x='X', y='Y', hue='Resolution', alpha=0.6, s=8 * 0.8) # size_multiplier=0.8\n", |
| 81 | + "plt.title('GPS Coordinates')\n", |
| 82 | + "plt.xlabel('Longitude')\n", |
| 83 | + "plt.ylabel('Latitude')\n", |
| 84 | + "plt.legend(loc='best', title='Resolution')\n", |
| 85 | + "plt.tight_layout()\n", |
| 86 | + "plt.show()\n", |
71 | 87 | "\n",
|
72 |
| - "# day of the year\n", |
| 88 | + "# Day of the week\n", |
| 89 | + "dataset.groupby(['DayOfWeek', 'Resolution']).size().unstack().plot(kind='bar', stacked=False)\n", |
| 90 | + "plt.xlabel('Day of the Week')\n", |
| 91 | + "plt.ylabel('Count')\n", |
| 92 | + "plt.title('Crimes by Day of the Week and Resolution')\n", |
| 93 | + "plt.tight_layout()\n", |
| 94 | + "plt.show()\n", |
| 95 | + "\n", |
| 96 | + "# week of the year\n", |
73 | 97 | "# For graphing we simplify this to week or the graph becomes overwhelmed with bars\n",
|
74 | 98 | "dataset[\"week_of_year\"] = np.round(dataset.day_of_year / 7.0)\n",
|
75 |
| - "graphing.multiple_histogram(dataset, \n", |
76 |
| - " label_x='week_of_year',\n", |
77 |
| - " label_group='Resolution',\n", |
78 |
| - " histfunc='sum', show=True)\n", |
| 99 | + "\n", |
| 100 | + "dataset.groupby(['week_of_year', 'Resolution']).size().unstack().plot(kind='bar', stacked=False)\n", |
| 101 | + "plt.xlabel('Week of the Year')\n", |
| 102 | + "plt.ylabel('Count')\n", |
| 103 | + "plt.title('Crimes by Week and Resolution')\n", |
| 104 | + "plt.tight_layout()\n", |
| 105 | + "plt.show()\n", |
| 106 | + "\n", |
79 | 107 | "del dataset[\"week_of_year\"]"
|
80 | 108 | ]
|
81 | 109 | },
|
|
308 | 336 | "metadata": {},
|
309 | 337 | "outputs": [],
|
310 | 338 | "source": [
|
311 |
| - "# Temporarily shrink the training set to something\n", |
312 |
| - "# more realistic\n", |
| 339 | + "# Temporarily shrink the training set to 10000\n", |
| 340 | + "# for this exercise to see how pruning is important\n", |
| 341 | + "# even with moderately large datasets\n", |
313 | 342 | "full_training_set = train\n",
|
314 |
| - "train = train[:100]\n", |
| 343 | + "train = train[:10000]\n", |
315 | 344 | "\n",
|
316 |
| - "# fit the same tree as before\n", |
317 |
| - "model = sklearn.tree.DecisionTreeClassifier(random_state=1, max_depth=100)\n", |
318 | 345 | "\n",
|
319 |
| - "# Assess on the same test set as before\n", |
320 |
| - "train_accuracy, test_accuracy = fit_and_test_model(model)\n", |
321 |
| - "print(\"Train accuracy\", train_accuracy)\n", |
322 |
| - "print(\"Test accuracy\", test_accuracy)\n", |
| 346 | + "# Loop through the values below and build a model\n", |
| 347 | + "# each time, setting the maximum depth to that value \n", |
| 348 | + "max_depth_range = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,15, 20, 50, 100]\n", |
| 349 | + "accuracy_trainset = []\n", |
| 350 | + "accuracy_testset = []\n", |
| 351 | + "for depth in max_depth_range:\n", |
| 352 | + " # Create and fit the model\n", |
| 353 | + " prune_model = sklearn.tree.DecisionTreeClassifier(random_state=1, max_depth=depth)\n", |
323 | 354 | "\n",
|
324 |
| - "# Roll the training set back to the full set\n", |
| 355 | + " # Calculate and record its sensitivity\n", |
| 356 | + " train_accuracy, test_accuracy = fit_and_test_model(prune_model)\n", |
| 357 | + " accuracy_trainset.append(train_accuracy)\n", |
| 358 | + " accuracy_testset.append(test_accuracy)\n", |
| 359 | + "\n", |
| 360 | + "# Plot the sensitivity as a function of depth \n", |
| 361 | + "pruned_plot = pandas.DataFrame(dict(max_depth=max_depth_range, accuracy=accuracy_trainset))\n", |
| 362 | + "\n", |
| 363 | + "plt.figure(figsize=(10, 6))\n", |
| 364 | + "plt.plot(max_depth_range, accuracy_trainset, marker='o', label='Train Accuracy')\n", |
| 365 | + "plt.plot(max_depth_range, accuracy_testset, marker='s', label='Test Accuracy')\n", |
| 366 | + "\n", |
| 367 | + "plt.title('Model Accuracy vs. Decision Tree Depth')\n", |
| 368 | + "plt.xlabel('Max Depth')\n", |
| 369 | + "plt.ylabel('Accuracy')\n", |
| 370 | + "plt.xticks(max_depth_range)\n", |
| 371 | + "plt.legend()\n", |
| 372 | + "plt.grid(True)\n", |
| 373 | + "plt.tight_layout()\n", |
| 374 | + "plt.show()\n", |
| 375 | + "\n", |
| 376 | + "# Roll the training set back to the full thing\n", |
325 | 377 | "train = full_training_set"
|
326 | 378 | ]
|
327 | 379 | },
|
|
377 | 429 | "# Plot the sensitivity as a function of depth \n",
|
378 | 430 | "pruned_plot = pandas.DataFrame(dict(max_depth=max_depth_range, accuracy=accuracy_trainset))\n",
|
379 | 431 | "\n",
|
380 |
| - "fig = graphing.line_2D(dict(train=accuracy_trainset, test=accuracy_testset), x_range=max_depth_range, show=True)\n", |
| 432 | + "plt.figure(figsize=(10, 6))\n", |
| 433 | + "plt.plot(max_depth_range, accuracy_trainset, marker='o', label='Train Accuracy')\n", |
| 434 | + "plt.plot(max_depth_range, accuracy_testset, marker='s', label='Test Accuracy')\n", |
| 435 | + "\n", |
| 436 | + "plt.title('Model Accuracy vs. Decision Tree Depth')\n", |
| 437 | + "plt.xlabel('Max Depth')\n", |
| 438 | + "plt.ylabel('Accuracy')\n", |
| 439 | + "plt.xticks(max_depth_range)\n", |
| 440 | + "plt.legend()\n", |
| 441 | + "plt.grid(True)\n", |
| 442 | + "plt.tight_layout()\n", |
| 443 | + "plt.show()\n", |
381 | 444 | "\n",
|
382 | 445 | "# Roll the training set back to the full thing\n",
|
383 | 446 | "train = full_training_set"
|
|
0 commit comments