|
53 | 53 | for trial in range(n_trials): |
54 | 54 | u.append(np.zeros((n_u, n_samp))) |
55 | 55 | for k in range(1, n_samp): |
56 | | - u_t = 0.975*u[trial][:,k-1] + 1e-1*np.random.normal(size=n_u) |
| 56 | + u_t = 0.975 * u[trial][:, k - 1] + 1e-1 * np.random.normal(size=n_u) |
57 | 57 | u[trial][:, k] = u_t |
58 | 58 |
|
59 | 59 | y, x, z = sys.simulate_block(u) |
60 | 60 |
|
61 | | -n_samp_imp = int(np.ceil(0.1/dt)) |
| 61 | +n_samp_imp = int(np.ceil(0.1 / dt)) |
62 | 62 | y_imp = sys.simulate_imp(n_samp_imp) |
63 | 63 | t_imp = np.arange(0, n_samp_imp * dt, dt) |
64 | 64 |
|
|
77 | 77 |
|
78 | 78 | # compare fit to original without state noise |
79 | 79 | sys_hat = plds.System(fit) |
80 | | -sys_hat.Q = np.zeros_like(sys_hat.Q) |
81 | 80 | y_hat, x_hat, _ = sys_hat.simulate_block(u) |
82 | 81 | y_imp_hat = sys_hat.simulate_imp(n_samp_imp) |
83 | 82 |
|
|
87 | 86 |
|
88 | 87 | fig, axs = plt.subplots(1, 2) |
89 | 88 | axs[0].semilogy(sing_vals[:n_h], "-o", color=[0.5, 0.5, 0.5]) |
90 | | -axs[0].semilogy(sing_vals[:n_h], color='k', linewidth=2) |
| 89 | +axs[0].semilogy(sing_vals[:n_h], color="k", linewidth=2) |
91 | 90 | axs[0].set(ylabel="Singular Values", xlabel="Singular Value Index") |
92 | 91 |
|
93 | 92 | l1 = axs[1].plot(t_imp, y_imp[0].T, "-", c="k", linewidth=2) |
|
112 | 111 | axs[0].plot(t, y[eg_trial][0, :] / dt, "k-") |
113 | 112 | axs[0].plot(t, y_hat[eg_trial][0, :] / dt, "-", c="gray", linewidth=2) |
114 | 113 | axs[0].legend(["measurement", "fit"]) |
115 | | -axs[0].set(ylabel="Output 1 (events/s)", xlabel="Time (s)", title=f"proportion var explained (training): {pve[0]:0.3f}") |
| 114 | +axs[0].set( |
| 115 | + ylabel="Output 1 (events/s)", |
| 116 | + xlabel="Time (s)", |
| 117 | + title=f"proportion var explained (training): {pve[0]:0.3f}", |
| 118 | +) |
116 | 119 |
|
117 | 120 | axs[1].plot(t, y[eg_trial][1, :] / dt, "k-") |
118 | 121 | axs[1].plot(t, y_hat[eg_trial][1, :] / dt, "-", c="gray", linewidth=2) |
119 | | -axs[1].set(ylabel="Output 2 (events/s)", xlabel="Time (s)", title=f"proportion var explained (training): {pve[1]:0.3f}") |
| 122 | +axs[1].set( |
| 123 | + ylabel="Output 2 (events/s)", |
| 124 | + xlabel="Time (s)", |
| 125 | + title=f"proportion var explained (training): {pve[1]:0.3f}", |
| 126 | +) |
120 | 127 |
|
121 | | -axs[2].plot(t, u[eg_trial].T, 'k') |
| 128 | +axs[2].plot(t, u[eg_trial].T, "k") |
122 | 129 | axs[2].set(ylabel="Input (a.u.)", xlabel="Time (s)") |
123 | 130 |
|
124 | 131 | fig.tight_layout() |
|
131 | 138 |
|
132 | 139 | # %% |
133 | 140 | # Refit by E-M |
134 | | -calc_dynamics = True #calculate dynamics (A, B mats) |
135 | | -calc_Q = True #calculate process noise cov (Q) |
136 | | -calc_init = True #calculate initial conditions |
137 | | -calc_output = True #calculate output (C) |
138 | | -calc_measurement = True #calculate output noise (R) |
| 141 | +calc_dynamics = True # calculate dynamics (A, B mats) |
| 142 | +calc_Q = True # calculate process noise cov (Q) |
| 143 | +calc_init = True # calculate initial conditions |
| 144 | +calc_output = True # calculate output (C) |
| 145 | +calc_measurement = True # calculate output noise (R) |
139 | 146 | max_iter = 50 |
140 | 147 | tol = 1e-2 |
141 | 148 |
|
142 | 149 | em = plds.FitEM(fit, u_train, z_train) |
143 | 150 |
|
144 | 151 | start = time.perf_counter() |
145 | | -fit_em = em.Run(calc_dynamics, calc_Q, calc_init, calc_output, calc_measurement, max_iter, tol) |
| 152 | +fit_em = em.Run( |
| 153 | + calc_dynamics, calc_Q, calc_init, calc_output, calc_measurement, max_iter, tol |
| 154 | +) |
146 | 155 | stop = time.perf_counter() |
147 | 156 | print(f"Finished EM fit in {(stop-start)*1000} ms.") |
148 | 157 |
|
|
162 | 171 | axs[0].legend(["measurement", "EM re-estimated"]) |
163 | 172 | axs[0].set(ylabel="Output (events/s)") |
164 | 173 |
|
165 | | -axs[1].plot(t, z[eg_trial][0, :], 'k') |
| 174 | +axs[1].plot(t, z[eg_trial][0, :], "k") |
166 | 175 |
|
167 | | -axs[2].plot(t, u[eg_trial].T, 'k') |
| 176 | +axs[2].plot(t, u[eg_trial].T, "k") |
168 | 177 | axs[2].set(ylabel="Input (a.u.)", xlabel="Time (s)") |
169 | 178 |
|
170 | 179 | fig.tight_layout() |
|
175 | 184 | fig, axs = plt.subplots(1, 2) |
176 | 185 |
|
177 | 186 | l1 = axs[1].plot(t_imp, y_imp[0].T, "-", c="k", linewidth=2) |
178 | | -l2 = axs[1].plot(t_imp, y_imp_hat_em[0].T, "-", c='gray', linewidth=2) |
| 187 | +l2 = axs[1].plot(t_imp, y_imp_hat_em[0].T, "-", c="gray", linewidth=2) |
179 | 188 | axs[1].legend([l1[0], l2[0]], ["ground truth", "EM re-estimated"]) |
180 | 189 | axs[1].set(ylabel="Impulse Response (a.u.)", xlabel="Time (s)") |
181 | 190 | fig.tight_layout() |
|
192 | 201 | axs[0].plot(t, y[eg_trial][0, :] / dt, "k-") |
193 | 202 | axs[0].plot(t, y_hat_em[eg_trial][0, :] / dt, "-", c="gray", linewidth=2) |
194 | 203 | axs[0].legend(["measurement", "EM re-estimated"]) |
195 | | -axs[0].set(ylabel="Output 1 (a.u.)", xlabel="Time (s)", title=f"EM-refit proportion var explained (training): {pve_em[0]:0.3f}") |
| 204 | +axs[0].set( |
| 205 | + ylabel="Output 1 (a.u.)", |
| 206 | + xlabel="Time (s)", |
| 207 | + title=f"EM-refit proportion var explained (training): {pve_em[0]:0.3f}", |
| 208 | +) |
196 | 209 |
|
197 | 210 | axs[1].plot(t, y[eg_trial][1, :] / dt, "k-") |
198 | 211 | axs[1].plot(t, y_hat_em[eg_trial][1, :] / dt, "-", c="gray", linewidth=2) |
199 | | -axs[1].set(ylabel="Output 2 (a.u.)", xlabel="Time (s)", title=f"EM-refit proportion var explained (training): {pve_em[1]:0.3f}") |
| 212 | +axs[1].set( |
| 213 | + ylabel="Output 2 (a.u.)", |
| 214 | + xlabel="Time (s)", |
| 215 | + title=f"EM-refit proportion var explained (training): {pve_em[1]:0.3f}", |
| 216 | +) |
200 | 217 |
|
201 | | -axs[2].plot(t, u[eg_trial].T, 'k') |
| 218 | +axs[2].plot(t, u[eg_trial].T, "k") |
202 | 219 | axs[2].set(ylabel="Input (a.u.)", xlabel="Time (s)") |
203 | 220 |
|
204 | 221 | fig.tight_layout() |
|
0 commit comments