Skip to content

Commit 8a86738

Browse files
author
gitName
committed
tweak code
1 parent 4ce45ce commit 8a86738

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

learn-pr/azure/test-machine-learning-models/notebooks/5-3-exercise-feature-normalization.ipynb

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@
5757
"import numpy as np\n",
5858
"\n",
5959
"# Train model using Gradient Descent\n",
60-
"# This method uses custom code that will print out progress as training advances.\n",
61-
"# You don't need to inspect how this works for these exercises, but if you are\n",
62-
"# curious, you can find it in out GitHub repository\n",
6360
"model = gradient_descent(data.month_old_when_trained, data.mean_rescues_per_year, learning_rate=5E-4, number_of_iterations=8000)\n"
6461
]
6562
},

learn-pr/azure/test-machine-learning-models/notebooks/5-5-exercise-test-training-datasets.ipynb

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,29 +50,31 @@
5050
"outputs": [],
5151
"source": [
5252
"import statsmodels.formula.api as smf\n",
53+
"import matplotlib.pyplot as plt\n",
54+
"import numpy as np\n",
5355
"\n",
5456
"# First, we define our formula using a special syntax\n",
5557
"# This says that rescues_last_year is explained by weight_last_year\n",
5658
"formula = \"rescues_last_year ~ weight_last_year\"\n",
5759
"\n",
5860
"model = smf.ols(formula = formula, data = data).fit()\n",
5961
"\n",
62+
"# Extract x and y values\n",
6063
"x = data[\"weight_last_year\"]\n",
6164
"y = data[\"rescues_last_year\"]\n",
6265
"\n",
63-
"# Create scatter plot and trendline\n",
66+
"# Scatter plot of the data and trendline\n",
6467
"plt.figure(figsize=(8, 6))\n",
65-
"plt.scatter(x, y, label=\"Data\", alpha=0.7)\n",
68+
"plt.scatter(x, y, alpha=0.7, label=\"Data\")\n",
6669
"\n",
6770
"x_vals = np.linspace(x.min(), x.max(), 100)\n",
68-
"x_vals_df = pandas.DataFrame({\"month_old_when_trained\": x_vals})\n",
69-
"y_preds = model_norm.predict(x_vals_df)\n",
71+
"y_vals = model.params[1] * x_vals + model.params[0] # Slope * x + Intercept\n",
7072
"\n",
71-
"plt.plot(x_vals, y_preds, color=\"red\", label=\"Trendline\")\n",
73+
"plt.plot(x_vals, y_vals, color=\"red\", label=\"Trendline (Linear Regression)\")\n",
7274
"\n",
73-
"plt.xlabel(\"Weight last year\")\n",
74-
"plt.ylabel(\"Mean Rescues Per Year\")\n",
75-
"plt.title(\"Rescues last eyar\")\n",
75+
"plt.xlabel(\"Weight Last Year\")\n",
76+
"plt.ylabel(\"Rescues Last Year\")\n",
77+
"plt.title(\"Rescues vs Weight with Linear Trendline\")\n",
7678
"plt.legend()\n",
7779
"plt.grid(True)\n",
7880
"plt.tight_layout()\n",

0 commit comments

Comments
 (0)