|
2 | 2 | # |
3 | 3 | import matplotlib.pyplot as plt |
4 | 4 | import numpy as np |
| 5 | +from scipy.stats import multivariate_normal |
5 | 6 |
|
6 | 7 | from helpers import assert_equality |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def plot(): |
10 | | - delta = 0.8 |
11 | | - x = y = np.arange(-3.0, 3.01, delta) |
12 | | - X, Y = np.meshgrid(x, y) |
13 | | - Z1 = plt.mlab.bivariate_normal(X, Y, 1.0, 1.0, 0.0, 0.0) |
14 | | - Z2 = plt.mlab.bivariate_normal(X, Y, 1.5, 0.5, 1, 1) |
15 | | - Z = 10 * (Z1 - Z2) |
16 | | - levels = [-2, -1.5, -1.2, -0.9, -0.6, -0.3, 0.0, 0.3, 0.6, 0.9, 1.2, 1.5] |
| 11 | + mean = np.array([1, 1]) |
| 12 | + cov = np.eye(2) |
| 13 | + nbins = 5 |
| 14 | + |
17 | 15 | fig = plt.figure() |
18 | | - CS = plt.contourf(X, Y, Z, 10, levels=levels) |
19 | | - plt.contour(CS, levels=levels, colors="r") |
| 16 | + ax = plt.gca() |
| 17 | + |
| 18 | + x_max = 2 |
| 19 | + x_min = 0 |
| 20 | + y_max = 2 |
| 21 | + y_min = 0 |
| 22 | + |
| 23 | + xi, yi = np.mgrid[x_min:x_max:nbins * 1j, y_min:y_max:nbins * 1j] |
| 24 | + pos = np.empty(xi.shape + (2,)) |
| 25 | + pos[:, :, 0] = xi |
| 26 | + pos[:, :, 1] = yi |
| 27 | + zi = multivariate_normal(mean, cov, allow_singular=True, seed=0).pdf(pos) |
| 28 | + ax.contourf(xi, yi, zi, 250) |
| 29 | + |
| 30 | + ax.set_xlim(x_min, x_max) |
| 31 | + ax.set_ylim(y_min, y_max) |
20 | 32 | return fig |
21 | 33 |
|
22 | 34 |
|
|
0 commit comments