|
1 | 1 | --- |
2 | | -sidebar_position: 4 |
| 2 | +sidebar_position: 3 |
3 | 3 | title: 线性回归 |
4 | 4 | --- |
5 | 5 |
|
@@ -29,6 +29,70 @@ title: 线性回归 |
29 | 29 |
|
30 | 30 | 我们预期中,理想效果应该是 0、0 好于 -4、4 好于 7、1。只有均方误差正确的反应了这一点。 |
31 | 31 |
|
| 32 | +通过误差的大小,我们可以慢慢修正我们的参数让线性拟合更好,导数可以反应数据变化的趋势,所以我们可以求导来修改参数。 |
| 33 | + |
| 34 | +```python showLineNumbers |
| 35 | +import numpy as np |
| 36 | +from matplotlib import pyplot as plt |
| 37 | + |
| 38 | + |
| 39 | +class Line: |
| 40 | + def __init__(self, data): |
| 41 | + self.w = 1 |
| 42 | + self.b = 0 |
| 43 | + self.learning_rate = 0.01 |
| 44 | + self.fig, (self.ax1, self.ax2) = plt.subplots(2, 1) |
| 45 | + self.loss_list = [] |
| 46 | + |
| 47 | + def get_data(self, data): |
| 48 | + self.X = np.array(data)[:, 0] |
| 49 | + self.y = np.array(data)[:, 1] |
| 50 | + |
| 51 | + def predict(self, x): |
| 52 | + return self.w * x + self.b |
| 53 | + |
| 54 | + def train(self, epoch_times): |
| 55 | + for epoch in range(epoch_times): |
| 56 | + total_loss = 0 |
| 57 | + for x, y in zip(self.X, self.y): |
| 58 | + y_pred = self.predict(x) |
| 59 | + # Calculate gradients |
| 60 | + gradient_w = -2 * x * (y - y_pred) |
| 61 | + gradient_b = -2 * (y - y_pred) |
| 62 | + # Update weights |
| 63 | + self.w -= self.learning_rate * gradient_w |
| 64 | + self.b -= self.learning_rate * gradient_b |
| 65 | + # Calculate loss |
| 66 | + loss = (y - y_pred) ** 2 |
| 67 | + total_loss += loss |
| 68 | + epoch_loss = total_loss / len(self.X) |
| 69 | + self.loss_list.append(epoch_loss) |
| 70 | + if epoch % 10 == 0: |
| 71 | + print(f"loss: {epoch_loss}") |
| 72 | + self.plot() |
| 73 | + plt.ioff() |
| 74 | + plt.show() |
| 75 | + |
| 76 | + def plot(self): |
| 77 | + plt.ion() # Enable interactive mode |
| 78 | + self.ax2.clear() |
| 79 | + self.ax1.clear() |
| 80 | + x = np.linspace(0, 10, 100) |
| 81 | + self.ax1.scatter(self.X, self.y, c="g") |
| 82 | + self.ax1.plot(x, self.predict(x), c="b") |
| 83 | + self.ax2.plot(list(range(len(self.loss_list))), self.loss_list) |
| 84 | + plt.show() |
| 85 | + plt.pause(0.1) |
| 86 | + |
| 87 | +if __name__ == "__main__": |
| 88 | + # Input data |
| 89 | + data = [(1, 1), (1.8, 2), (2.5, 3), (4.2, 4), (5, 5), (6, 6), (7, 7)] |
| 90 | + s = Line(data) |
| 91 | + s.get_data(data) |
| 92 | + s.train(100) |
| 93 | +``` |
| 94 | + |
| 95 | +## 使用sklearn模块完成 |
32 | 96 |
|
33 | 97 | ```python showLineNumbers |
34 | 98 | import numpy as np |
|
0 commit comments