Skip to content

Commit 28d1061

Browse files
authored
Merge pull request #265 from Dekken/add-2d-linear-regression-example
Added a new example for 2d linear regression
2 parents b769855 + 04bdb8b commit 28d1061

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""
2+
===================================
3+
Linear regression in two dimensions
4+
===================================
5+
6+
In this example, we try to predict the median price of houses in Boston's
7+
neighbors by looking at two features: the average of the number of rooms per
8+
dwelling and the pencentage of low status in the population.
9+
The linear regression is done using the `Boston Housing Dataset`_ which
10+
contains 13 features used to predict the median price of houses.
11+
The two features selected are the most efficent on the test set.
12+
The 3D representation allows a better understanding of the prediction
13+
mechanism with two features.
14+
15+
This example is inspired by linear regression example from
16+
`scikit-learn documentation`_.
17+
18+
.. _Boston Housing Dataset: https://www.kaggle.com/c/boston-housing
19+
.. _scikit-learn documentation: http://scikit-learn.org/stable/auto_examples/linear_model/plot_ols.html#sphx-glr-auto-examples-linear-model-plot-ols-py
20+
"""
21+
22+
import matplotlib.pyplot as plt
23+
import numpy as np
24+
from tick import linear_model
25+
from sklearn.metrics import mean_squared_error, r2_score
26+
from sklearn.utils import shuffle
27+
from sklearn.datasets import load_boston
28+
from mpl_toolkits.mplot3d import axes3d
29+
from matplotlib import cm
30+
31+
# Load the Boston Housing Dataset
32+
features, label = load_boston(return_X_y=True)
33+
features, label = shuffle(features, label, random_state=0)
34+
35+
# Use two features: the average of the number of rooms per dwelling and
36+
# the pencentage of low status of the population
37+
X = features[:, [5, 12]]
38+
39+
# Split the data into training/testing sets
40+
n_train_data = int(0.8 * X.shape[0])
41+
42+
X_train = X[:n_train_data]
43+
X_test = X[n_train_data:]
44+
45+
y_train = label[:n_train_data]
46+
y_test = label[n_train_data:]
47+
48+
# Create linear regression and fit it on the training set
49+
regr = linear_model.LinearRegression()
50+
regr.fit(X_train, y_train)
51+
52+
# Make predictions using the testing set
53+
y_pred = regr.predict(X_test)
54+
55+
print('Coefficients:')
56+
print(' intercept: {:.2f}'.format(regr.intercept))
57+
print(' average room per dwelling: {:.2f}'.format(regr.weights[0]))
58+
print(' percentage of low status in population: {:.2f}'
59+
.format(regr.weights[1]))
60+
61+
# The mean squared error
62+
print('Mean squared error on test set: {:.2f}'.format(
63+
mean_squared_error(y_test, y_pred)))
64+
65+
# Explained variance score: 1 is perfect prediction
66+
print('Variance score on test set: {:.2f}'.format(r2_score(y_test, y_pred)))
67+
# To work in 3D
68+
69+
# We first generate a mesh grid
70+
resolution = 10
71+
x = X_test[:, 0]
72+
y = X_test[:, 1]
73+
z = y_test
74+
75+
x_surf = np.linspace(min(x), max(x), resolution)
76+
y_surf = np.linspace(min(y), max(y), resolution)
77+
x_surf, y_surf = np.meshgrid(x_surf, y_surf)
78+
79+
# and then predict the label for all values in the grid
80+
z_surf = np.zeros_like(x_surf)
81+
mesh_points = np.vstack((x_surf.ravel(), y_surf.ravel())).T
82+
z_surf.ravel()[:] = regr.predict(mesh_points)
83+
84+
fig = plt.figure(figsize=(20, 5))
85+
86+
# 3D representation under different rotated angles for a better visualazion
87+
xy_angles = [10, 35, 60, 85]
88+
z_angle = 20
89+
90+
for i, angle in enumerate(xy_angles):
91+
n_columns = len(xy_angles)
92+
position = i + 1
93+
94+
ax = fig.add_subplot(1, n_columns, position, projection='3d')
95+
96+
ax.view_init(z_angle, angle)
97+
98+
ax.plot_surface(x_surf, y_surf, z_surf, cmap=cm.hot, rstride=1, cstride=1,
99+
alpha=0.3, linewidth=0.2, edgecolors='black')
100+
ax.scatter(x, y, z)
101+
102+
ax.set_title('angle: {}°'.format(angle))
103+
ax.set_zlabel('median house pricing')
104+
ax.set_xlabel('avg room per dwelling')
105+
ax.set_ylabel('% low status population')
106+
107+
plt.show()

0 commit comments

Comments
 (0)