|
27 | 27 | "!pip install statsmodels\n",
|
28 | 28 | "!wget https://raw.githubusercontent.com/MicrosoftDocs/mslearn-introduction-to-machine-learning/main/graphing.py\n",
|
29 | 29 | "!wget https://raw.githubusercontent.com/MicrosoftDocs/mslearn-introduction-to-machine-learning/main/Data/avalanche.csv\n",
|
30 |
| - "import graphing # custom graphing code. See our GitHub repo for details\n", |
31 | 30 | "\n",
|
32 | 31 | "#Import the data from the .csv file\n",
|
33 | 32 | "dataset = pandas.read_csv('avalanche.csv', delimiter=\"\\t\", index_col=0)\n",
|
|
254 | 253 | "metadata": {},
|
255 | 254 | "outputs": [],
|
256 | 255 | "source": [
|
257 |
| - "graphing.model_to_surface_plot(model_with_interaction, [\"weak_layers\", \"wind\"], test)" |
| 256 | + "import matplotlib.pyplot as plt\n", |
| 257 | + "import numpy as np\n", |
| 258 | + "import pandas as pd\n", |
| 259 | + "from mpl_toolkits.mplot3d import Axes3D\n", |
| 260 | + "\n", |
| 261 | + "def predict(weak_layers, wind, surface_hoar=0, fresh_thickness=0, no_visitors=0):\n", |
| 262 | + " return model_with_interaction.predict(pd.DataFrame({\n", |
| 263 | + " \"weak_layers\": weak_layers,\n", |
| 264 | + " \"wind\": wind,\n", |
| 265 | + " \"surface_hoar\": surface_hoar,\n", |
| 266 | + " \"fresh_thickness\": fresh_thickness,\n", |
| 267 | + " \"no_visitors\": no_visitors\n", |
| 268 | + " }))\n", |
| 269 | + "\n", |
| 270 | + "# Generate a graph for weak_layers and wind\n", |
| 271 | + "weak_layers = np.linspace(-20, 20, 100)\n", |
| 272 | + "wind = np.linspace(-20, 20, 100)\n", |
| 273 | + "weak_layers_grid, wind_grid = np.meshgrid(weak_layers, wind)\n", |
| 274 | + "\n", |
| 275 | + "predicted_values = predict(weak_layers_grid.ravel(), wind_grid.ravel()).to_numpy().reshape(weak_layers_grid.shape)\n", |
| 276 | + "\n", |
| 277 | + "fig = plt.figure()\n", |
| 278 | + "ax = fig.add_subplot(111, projection='3d')\n", |
| 279 | + "ax.plot_surface(weak_layers_grid, wind_grid, predicted_values, cmap='viridis')\n", |
| 280 | + "\n", |
| 281 | + "ax.set_xlabel(\"Weak Layers\")\n", |
| 282 | + "ax.set_ylabel(\"Wind\")\n", |
| 283 | + "ax.set_zlabel(\"Predicted Values\")\n", |
| 284 | + "\n", |
| 285 | + "plt.show()" |
258 | 286 | ]
|
259 | 287 | },
|
260 | 288 | {
|
261 | 289 | "cell_type": "markdown",
|
262 | 290 | "id": "1d8af2ee",
|
263 | 291 | "metadata": {},
|
264 | 292 | "source": [
|
265 |
| - "The graph is interactive - rotate it and explore how there's a clear s-shaped relationship between the features and probability.\n", |
| 293 | + "There's now a clear s-shaped relationship between the features and probability.\n", |
266 | 294 | "\n",
|
267 | 295 | "Let's now look at the features that we've said can interact:"
|
268 | 296 | ]
|
|
274 | 302 | "metadata": {},
|
275 | 303 | "outputs": [],
|
276 | 304 | "source": [
|
277 |
| - "graphing.model_to_surface_plot(model_with_interaction, [\"no_visitors\", \"fresh_thickness\"], test)" |
| 305 | + "import matplotlib.pyplot as plt\n", |
| 306 | + "import numpy as np\n", |
| 307 | + "import pandas as pd\n", |
| 308 | + "from mpl_toolkits.mplot3d import Axes3D\n", |
| 309 | + "\n", |
| 310 | + "def predict(no_visitors, fresh_thickness, weak_layers=0, wind=0, surface_hoar=0):\n", |
| 311 | + " return model_with_interaction.predict(pd.DataFrame({\n", |
| 312 | + " \"no_visitors\": no_visitors,\n", |
| 313 | + " \"fresh_thickness\": fresh_thickness,\n", |
| 314 | + " \"weak_layers\": weak_layers,\n", |
| 315 | + " \"wind\": wind,\n", |
| 316 | + " \"surface_hoar\": surface_hoar\n", |
| 317 | + " }))\n", |
| 318 | + "\n", |
| 319 | + "# Generate the graph\n", |
| 320 | + "no_visitors = np.linspace(-20, 20, 100)\n", |
| 321 | + "fresh_thickness = np.linspace(-20, 20, 100)\n", |
| 322 | + "no_visitors_grid, fresh_thickness_grid = np.meshgrid(no_visitors, fresh_thickness)\n", |
| 323 | + "\n", |
| 324 | + "predicted_values = predict(no_visitors_grid.ravel(), fresh_thickness_grid.ravel()).to_numpy().reshape(no_visitors_grid.shape)\n", |
| 325 | + "\n", |
| 326 | + "fig = plt.figure()\n", |
| 327 | + "ax = fig.add_subplot(111, projection='3d')\n", |
| 328 | + "ax.plot_surface(no_visitors_grid, fresh_thickness_grid, predicted_values, cmap='viridis')\n", |
| 329 | + "\n", |
| 330 | + "ax.set_xlabel(\"No Visitors\")\n", |
| 331 | + "ax.set_ylabel(\"Fresh Thickness\")\n", |
| 332 | + "ax.set_zlabel(\"Predicted Values\")\n", |
| 333 | + "\n", |
| 334 | + "plt.show()" |
278 | 335 | ]
|
279 | 336 | },
|
280 | 337 | {
|
|
0 commit comments