diff --git a/lectures/kalman_2.md b/lectures/kalman_2.md index a1eb32490..5563ca11c 100644 --- a/lectures/kalman_2.md +++ b/lectures/kalman_2.md @@ -447,9 +447,9 @@ def simulate_workers(worker, T, ax, mu_0=None, Sigma_0=None, A, C, G, R = worker.A, worker.C, worker.G, worker.R xhat_0, Σ_0 = worker.xhat_0, worker.Σ_0 - if isinstance(mu_0, type(None)): + if mu_0 is None: mu_0 = xhat_0 - if isinstance(Sigma_0, type(None)): + if Sigma_0 is None: Sigma_0 = worker.Σ_0 ss = LinearStateSpace(A, C, G, np.sqrt(R), @@ -471,12 +471,12 @@ def simulate_workers(worker, T, ax, mu_0=None, Sigma_0=None, kalman.update(y[i]) x_hat, Σ = kalman.x_hat, kalman.Sigma Σ_t.append(Σ) - [y_hat_t[i]] = worker.G @ x_hat - [u_hat_t[i]] = x_hat[1] + y_hat_t[i] = (worker.G @ x_hat).item() + u_hat_t[i] = x_hat[1].item() - if diff == True: + if diff : title = ('Difference between inferred and true work ethic over time' - if title == None else title) + if title is None else title) ax.plot(u_hat_t - u_0, alpha=.5) ax.axhline(y=0, color='grey', linestyle='dashed') @@ -485,10 +485,10 @@ def simulate_workers(worker, T, ax, mu_0=None, Sigma_0=None, ax.set_title(title) else: - label_line = (r'$E[u_t|y^{t-1}]$' if name == None + label_line = (r'$E[u_t|y^{t-1}]$' if name is None else name) title = ('Inferred work ethic over time' - if title == None else title) + if title is None else title) u_hat_plot = ax.plot(u_hat_t, label=label_line) ax.axhline(y=u_0, color=u_hat_plot[0].get_color(),