forked from ahmadianlab/gg3_nda
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtask_4_2_fano.py
More file actions
68 lines (55 loc) · 1.77 KB
/
task_4_2_fano.py
File metadata and controls
68 lines (55 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# import
from inference import *
from HMM_models import *
import matplotlib.pyplot as plt
import numpy as np
# Parameters shared by both models: x0, Rh and T are passed to the HMM_Ramp
# and HMM_Step constructors further down in this script.
# NOTE(review): their physical meaning is defined in HMM_models -- confirm there.
x0 = 0.2
Rh = 75
# T is also handed to trial_fano as the number of time bins to analyse.
T = 100
# K is only passed to HMM_Ramp.
K = 100
# Ramp-model parameters: the first two positional arguments of HMM_Ramp.
beta = 2
sigma = 2
# Step-model parameters: the first two positional arguments of HMM_Step.
m = 40
r = 5
def trial_fano(model, iterations, t):
    """Estimate the per-bin mean, variance and Fano factor across trials.

    Calls ``model.simulate()`` ``iterations`` times, keeps the third element
    of each returned tuple (the spike array) and computes across-trial
    statistics for the first ``t`` time bins.

    Args:
        model: Object whose ``simulate()`` returns ``(latent, rate, spikes)``.
            ``spikes`` is assumed to be a 1-D array with at least ``t``
            entries -- TODO confirm against HMM_models.
        iterations: Number of trials to simulate; must be at least 1.
        t: Number of leading time bins to analyse.

    Returns:
        Tuple ``(mean, var, fano)`` of 1-D float arrays with one entry per
        analysed bin. ``var`` is the population variance (normalised by the
        number of trials) and ``fano`` is ``var / mean``. Bins whose mean is
        zero yield NaN (or inf when the variance is non-zero).

    Raises:
        ValueError: If ``iterations`` is less than 1.
        IndexError: If ``t`` exceeds the number of bins in a simulated trial.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    # One row per simulated trial. The matrix is built from the simulated
    # rows only, so every row holds real data and every trial is counted.
    trials = np.vstack(
        [model.simulate()[2] for _ in range(iterations)]
    ).astype(float)

    n_bins = trials.shape[1]
    if t > n_bins:
        raise IndexError(
            "t={} exceeds the {} bins per simulated trial".format(t, n_bins)
        )
    # max(t, 0) keeps a non-positive t returning empty arrays instead of
    # being interpreted as a from-the-end slice.
    counts = trials[:, :max(t, 0)]

    mean = counts.mean(axis=0)
    var = counts.var(axis=0)  # ddof=0: population variance, E[x^2] - E[x]^2

    # Silent bins (mean == 0) make the ratio undefined; keep NaN/inf in the
    # result but do not emit a RuntimeWarning for every such bin.
    with np.errstate(divide="ignore", invalid="ignore"):
        fano = var / mean
    return mean, var, fano
# Ramp model: overlay the Fano-factor time course for each value of gamma,
# which is passed to HMM_Ramp as isi_gamma_shape, using 1000 simulated trials.
for gamma in range(2, 6):
    model = HMM_Ramp(beta, sigma, K, x0, Rh, T, isi_gamma_shape=gamma)
    mean, var, fano = trial_fano(model, 1000, T)
    # Evenly spaced x positions on [0, 1), one per bin.
    # NOTE(review): assumes the T bins span one second, as the axis label
    # implies -- confirm against HMM_models.
    spike_times = np.linspace(0, 1, num=fano.shape[0], endpoint=False)
    # Raw strings keep the LaTeX backslashes literal ('\g' and '\s' are not
    # valid escape sequences in a normal string literal).
    plt.plot(spike_times, fano, label=r'Fano factor for $\gamma$ = ' + str(gamma))
plt.title(
    'Fano factor of ramp model '
    + r'$\beta$=' + str(beta)
    + r' $\sigma$=' + str(sigma)
)
plt.xlabel('time (s)')
plt.legend()
plt.show()
# Step model: overlay the Fano-factor time course for each value of gamma,
# which is passed to HMM_Step as isi_gamma_shape, using 1000 simulated trials.
for gamma in range(2, 6):
    model = HMM_Step(m, r, x0, Rh, T, isi_gamma_shape=gamma)
    mean, var, fano = trial_fano(model, 1000, T)
    # Evenly spaced x positions on [0, 1), one per bin.
    # NOTE(review): assumes the T bins span one second, as the axis label
    # implies -- confirm against HMM_models.
    spike_times = np.linspace(0, 1, num=fano.shape[0], endpoint=False)
    # Raw string keeps the LaTeX backslash literal ('\g' is not a valid
    # escape sequence in a normal string literal).
    plt.plot(spike_times, fano, label=r'Fano factor for $\gamma$ = ' + str(gamma))
plt.title('Fano factor of step model ' + 'm =' + str(m) + ' r=' + str(r))
plt.xlabel('time (s)')
plt.legend()
plt.show()
#plt.plot(spike_times, mean, label = 'mean')
#plt.plot(spike_times, var, label = 'var')
#plt.plot(spike_times, fano, label = 'fano')
#plt.show()