|
50 | 50 | "outputs": [],
|
51 | 51 | "source": [
|
52 | 52 | "import statsmodels.formula.api as smf\n",
|
| 53 | + "import matplotlib.pyplot as plt\n", |
| 54 | + "import numpy as np\n", |
53 | 55 | "\n",
|
54 | 56 | "# First, we define our formula using a special syntax\n",
|
55 | 57 | "# This says that rescues_last_year is explained by weight_last_year\n",
|
56 | 58 | "formula = \"rescues_last_year ~ weight_last_year\"\n",
|
57 | 59 | "\n",
|
58 | 60 | "model = smf.ols(formula = formula, data = data).fit()\n",
|
59 | 61 | "\n",
|
| 62 | + "# Extract x and y values\n", |
60 | 63 | "x = data[\"weight_last_year\"]\n",
|
61 | 64 | "y = data[\"rescues_last_year\"]\n",
|
62 | 65 | "\n",
|
63 |
| - "# Create scatter plot and trendline\n", |
| 66 | + "# Scatter plot of the data and trendline\n", |
64 | 67 | "plt.figure(figsize=(8, 6))\n",
|
65 |
| - "plt.scatter(x, y, label=\"Data\", alpha=0.7)\n", |
| 68 | + "plt.scatter(x, y, alpha=0.7, label=\"Data\")\n", |
66 | 69 | "\n",
|
67 | 70 | "x_vals = np.linspace(x.min(), x.max(), 100)\n",
|
68 |
| - "x_vals_df = pandas.DataFrame({\"month_old_when_trained\": x_vals})\n", |
69 |
| - "y_preds = model_norm.predict(x_vals_df)\n", |
| 71 | + "y_vals = model.params[1] * x_vals + model.params[0] # Slope * x + Intercept\n", |
70 | 72 | "\n",
|
71 |
| - "plt.plot(x_vals, y_preds, color=\"red\", label=\"Trendline\")\n", |
| 73 | + "plt.plot(x_vals, y_vals, color=\"red\", label=\"Trendline (Linear Regression)\")\n", |
72 | 74 | "\n",
|
73 |
| - "plt.xlabel(\"Weight last year\")\n", |
74 |
| - "plt.ylabel(\"Mean Rescues Per Year\")\n", |
75 |
| - "plt.title(\"Rescues last eyar\")\n", |
| 75 | + "plt.xlabel(\"Weight Last Year\")\n", |
| 76 | + "plt.ylabel(\"Rescues Last Year\")\n", |
| 77 | + "plt.title(\"Rescues vs Weight with Linear Trendline\")\n", |
76 | 78 | "plt.legend()\n",
|
77 | 79 | "plt.grid(True)\n",
|
78 | 80 | "plt.tight_layout()\n",
|
|
0 commit comments