%%
% Adapted from Kording, Tenenbaum, and Shadmehr 2007.
%
%% Joint parameters - initialize transition and observation matrices across all examples.
% Construct these matrices
states=30; % number of hidden states (one per timescale)
taus=exp(-linspace(log(0.000003),log(0.5),states)); % timescales, log-spaced from 1/0.000003 (about 3.3e5) down to 1/0.5 = 2 trials
A=diag(1-1./taus); % transition matrix A: diagonal, so each state decays as x <- (1-1/tau)*x, using equation (1) (matrix M from the appendix)
C = ones(1,states); % the observation matrix C - matrix H from appendix.
Q = diag(1./taus); % the state noise matrix Q - matrix Q from appendix
Q=0.000001475*Q/sum(Q(:)); % normalizing Q to a fixed total variance makes it
% easier to experiment with other power laws for Q.
% With this scaling c=0.001, but it's easy to play
% with the parameters.
R = (0.05).^2; % the observation noise R - sigma_w^2 from appendix
initx = zeros(states,1); %system starts out in unperturbed state
initV = diag(1e-6*ones(states,1)); %% rough estimate of variance
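% Added sanity check (a sketch, not part of the original replication):
% with A = diag(1-1./taus), a state with timescale tau evolves as
% x(t+1) = (1-1/tau)*x(t) + noise, so for tau >> 1 it keeps roughly
% exp(-1) of its value after tau trials. Every diagonal entry of A must lie
% strictly between 0 and 1 (i.e. every tau > 1) for all states to decay.
assert(all(diag(A) > 0 & diag(A) < 1), 'Each hidden state must decay: need 0 < 1-1/tau < 1.');
fprintf('Timescales range from %.3g to %.3g trials\n', min(taus), max(taus));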
% Note that there are 30 memory states (30 'disturbances') whose sum produces the
% observed error. The gain is implicit: the error is always assumed to be around 0
% and disturbances are additive, so the model estimates an additive offset rather
% than a multiplicative gain.
%
% We look for the states that, when summed, equal the error; that is why C is all ones.
% ypred = C*xpred, which can be compared to any y.
%% Improve estimate of initV
% This is just to get a better estimate of initV, and shows an example of
% the use of sample_lds and kalman_filter.
T = 40000;
[x0,y0] = sample_lds(A, C, Q, R, initx, T);
[xfilt, Vfilt, VVfilt, loglik, xpred] = kalman_filter(y0, A, C, Q, R, initx, initV);
initV=Vfilt(:,:,end);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Replicate Hopp and Fuchs
% Reproduces figure 2b: simulate the data (sample_lds), then perform inference on it (kalman_filter).
%{
In these experiments, subjects had to make a 10 degree saccade to a particular target,
but after 200 'practice' trials the target moves during the saccade.
Since saccades are too short for feedback to adjust them mid-saccade,
the subject can only adjust the gain on a trial-by-trial basis.
The subject eventually adapts by reducing the gain of the saccade (to correct the overshoot).
Then at trial 1400 the target stops moving, and the subject adapts back to the original location.
%}
T = 4200; % length of the experiment, i.e. number of trials
% INSERT CODE
% Simulate the unperturbed plant with sample_lds, then perturb it:
[x0,y] = sample_lds(A, C, Q, R, initx, T); % simulate an unperturbed plant
y(1201:2600)=y(1201:2600)-0.3; % external perturbation on y: a 30% disturbance on trials 1201-2600
% Then run the Kalman filter on it, producing xpred for the plots below.
[xfilt, Vfilt, VVfilt, loglik, xpred] = kalman_filter(y, A, C, Q, R, initx, initV);
% HERE
a=sum(xpred(:,1001:end)); % discard the first 1000 trials so that initV doesn't matter much
% We weigh all the hidden states equally, so just sum across all of them
% to get the resulting gain, then subtract the observations and reverse
% the external perturbation.
b=a-y(1001:end); % subtract the observations y from the summed predictions
b(201:1600)=b(201:1600)-0.3; % remove the perturbation so that the actual gain is plotted
%% Produce figure 2b, given xpred and y.
figure(1);
subplot(2,1,1)
plot(1+b,'.') % Plot the 'trials'
[paras]=fitExponential([1:1400],b(201:1600)); % Fit an exponential to the trial data
[paras2]=fitExponential([1:1600],b(1601:3200));
hold on
plot([201:1600],1+paras(1)+paras(2)*exp(paras(3)*[1:1400]),'r'); % Plot the exponentials
plot([1601:3200],1+paras2(1)+paras2(2)*exp(paras2(3)*[1:1600]),'r');
xlabel('time (saccades)')
ylabel('relative size of saccade');
title('Replication of standard target jump paradigm')
subplot(2,1,2)
imagesc(xpred(:,1001:end),[-1 1]*max(abs(xpred(:)))); % Plot the values of the hidden states.
colorbar
xlabel('time (saccades)')
ylabel('log timescale (2-33000)');
title('Inferred disturbances for all timescales');
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Replicate Kojima: faster adaptation the second time
% This replicates figure 3c.
%{
Double adaptation experiments. Following an initial set of trials,
there is a positive 35% perturbation of the target for 800 trials,
followed by a 35% negative perturbation of the target (from the initial position).
This continues until the gain is back to neutral (i.e. the subject correctly saccades to the initial position),
at which time the target is (again) positively perturbed by 35%.
This second positive perturbation is followed by a quicker adaptation towards the new target position,
indicating that some memory remains of the previous positive perturbation.
%}
T=4200;
for i=1:5 % repeat the simulation 5 times, recording the initial adaptation slopes
% INSERT CODE
% First, simulate the full trial sets:
[x0,y] = sample_lds(A, C, Q, R, initx, T); %simulate the plant
% Then apply a positive perturbation on trials 1001-1800, and a negative perturbation until the end.
y(1001:1800)=y(1001:1800)+0.35; % positive perturbation
y(1801:end)=y(1801:end)-0.35; % negative perturbation until gain=1
% Then run the Kalman filter on this perturbation.
[xfilt, Vfilt, VVfilt, loglik,xpred] = kalman_filter(y, A, C, Q, R, initx, initV);
% HERE
% Now, figure out when the gain goes back to normal.
nG=find(sum(xpred)<sum(xpred(:,1001)));
bord=min(nG(find(nG>1800)))+1; %figure out when the gain is back to normal
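% Added guard (a sketch, not in the original): if the summed prediction never
% drops back below its trial-1001 value after trial 1800, nG>1800 selects
% nothing, bord is empty, and the index ranges built from bord below break.
% Stop with a clear message instead.
assert(~isempty(bord), 'Gain never returned to baseline after trial 1800; try a larger T.');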
% Then apply the positive perturbation again from that point on: adding 0.7
% cancels the -0.35 offset and leaves a +0.35 perturbation.
y(bord:end)=y(bord:end)+0.7; %switch back to positive
% Now run the Kalman filter again...
[xfilt, Vfilt, VVfilt, loglik,xpred] = kalman_filter(y, A, C, Q, R, initx, initV);
% Fit a line to the first 200 trials of the first and second positive perturbations.
[parasF]=fitLinear([1:200],sum(xpred(:,1001:1200))-y(1001:1200));
[parasS]=fitLinear([1:200],sum(xpred(:,bord:bord+199))-y(bord:bord+199));
first(i)=parasF(2);
second(i)=parasS(2);
end
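% Added summary (a sketch, not in the original): first(i) and second(i) hold
% the fitted slopes over the first 200 trials of the first and second positive
% perturbations. Faster re-adaptation would show up as a larger mean slope
% the second time.
fprintf('Mean initial slope, first positive perturbation:  %.3g per trial\n', mean(first));
fprintf('Mean initial slope, second positive perturbation: %.3g per trial\n', mean(second));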
%% Produce figure 3c (uses xpred, the predicted states, and y, the observations, from the last loop iteration)
figure(2);
subplot(2,1,1);
a=sum(xpred(:,801:end));
b=a-y(801:end);
b(201:1000)=b(201:1000)+0.35; % remove perturbations to report standard gains (b is offset by 800 trials relative to y)
b(1001:bord-801)=b(1001:bord-801)-0.35;
b(bord-800:end)=b(bord-800:end)+0.35;
plot(1+b,'.');
[paras]=fitLinear([1:200],b(201:400)); % Fit linear functions to the resulting data.
[paras2]=fitLinear([1:200],b(bord-800:bord-800+199));
hold on
plot([201:800],1+paras(1)+paras(2)*[1:600],'r');
plot([bord-800:bord+599-800],1+paras2(1)+paras2(2)*[1:600],'r');
xlabel('time (saccades)')
ylabel('relative size of saccade');
title('adaptation, up down and up again');
subplot(2,1,2)
imagesc(xpred(:,1001:end),[-1 1]*max(abs(xpred(:))));
xlabel('time (saccades)')
ylabel('log timescale (2-33000)');
title('inferred disturbances for all timescales');
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Replicate Kojima's change in the dark experiment
% Figure 3g
%{
Here there is a period without information: after the gain resets following a reversal,
the subject is kept in darkness (no visual feedback) and then a positive perturbation is applied.
Note that the subject 'lost' some of the recent negative adaptation and showed spontaneous recovery
(in the original graph).
%}
T=4200;
% INSERT CODE
[x0,y] = sample_lds(A, C, Q, R, initx, T); %simulate the plant
y(1001:1800)=y(1001:1800)+0.35; % positive perturbation
y(1801:end)=y(1801:end)-0.35; % negative perturbation until gain=1
[xfilt, Vfilt, VVfilt, loglik,xpred] = kalman_filter(y, A, C, Q, R, initx, initV);
nG=find(sum(xpred)<sum(xpred(:,1001)));
bord=min(nG(find(nG>1800)))+1; %figure out when the gain is back to normal
y(bord:end)=y(bord:end)+0.7; %switch back to positive
% Now run the Kalman filter, but set isObserved to zero for the dark period.
% After the gain returns to baseline, the monkey receives no observations
% for 500 trials; these 500 trials represent 'darkness'.
[xfilt, Vfilt, VVfilt, loglik,xpred] = kalman_filter(y, A, C, Q, R, initx, initV,'isObserved',[ones(1,bord),zeros(1,500),ones(1,10000)]);
a=sum(xpred);
b=y-a; % undo the perturbations below so that b is relative to baseline
b(1001:1800)=b(1001:1800)-0.35;
b(1801:bord-1)=b(1801:bord-1)+0.35;
b(bord:end)=b(bord:end)-0.35;
%% produces figure 3g
figure(3)
clf
subplot(2,1,1)
plot(-b(1,1001:end).*[ones(1,bord-1-1000),zeros(1,501),ones(1,4200-bord-500)],'.')
xlabel('time (saccades)')
ylabel('relative size of saccade');
title('adaptation, up down, darkness, and up again');
subplot(2,1,2)
imagesc(xpred(:,1001:end),[-1 1]*max(abs(xpred(:))));
xlabel('time (saccades)')
ylabel('log timescale (2-33000)');
title('inferred disturbances for all timescales');
clear first second
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Replicate Kojima's no-ISS experiment - no perturbation after darkness
% Figure 3h
%{
This experiment is the same as the previous,
but instead of a positive perturbation after the dark period,
there is no perturbation (so the perturbation is set back to 0).
Note here that there is no spontaneous recovery.
%}
T=4200;
[x0,y] = sample_lds(A, C, Q, R, initx, T); %simulate the plant
y(1001:1800)=y(1001:1800)+0.35; % positive perturbation
y(1801:end)=y(1801:end)-0.35; % negative perturbation until gain=1
[xfilt, Vfilt, VVfilt, loglik,xpred] = kalman_filter(y, A, C, Q, R, initx, initV); % Run kalman filter
nG=find(sum(xpred)<sum(xpred(:,1001)));
bord=min(nG(find(nG>1800)))+1; %figure out when the gain is back to normal (gain = 1)
y(bord:end)=y(bord:end)+0.35; %switch perturbation to 0.
[xfilt, Vfilt, VVfilt, loglik,xpred] = kalman_filter(y, A, C, Q, R, initx, initV,'isObserved',[ones(1,bord),zeros(1,500),ones(1,10000)]);
a=sum(xpred); % sum across all hidden states to get the adaptation gain
b=y-a; % subtract the summed prediction from the observations, then revert the perturbations back to 0
b(1001:1800)=b(1001:1800)-0.35;
b(1801:bord-1)=b(1801:bord-1)+0.35;
%% Plot figure 3h
figure(6)
subplot(2,1,1)
plot(1-b(1,1001:end).*[ones(1,bord-1-1000),zeros(1,501),ones(1,4200-bord-500)],'.')
xlabel('time (saccades)')
ylabel('relative size of saccade');
title('up, down adaptation followed by darkness and no-ISS trials');
subplot(2,1,2)
imagesc(xpred(:,1001:end),[-1 1]*max(abs(xpred(:))));
xlabel('time (saccades)')
ylabel('log timescale (2-33000)');
title('inferred disturbances for all timescales');
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Robinson - memory over multiple days with darkness.
% Figure 2c
%{
In these experiments, subjects went through the adaptation training (with a 50% offset) over multiple days,
each day having 1500 trials, and were kept in the dark for the rest of the time.
Note that the subjects' rate of learning is faster on each successive day, until they adapt almost instantly.
%}
T = 5*3000; % 5 simulated days; each 3000-trial block is 1500 trials in the dark followed by 1500 training trials.
[x0,y0] = sample_lds(A, C, Q, R, initx, T); % simulate the plant
y=y0-0.5; % perturb the plant by 50%.
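isObserved=ones(size(y)); % isObserved is 1 during the 'day' (training) and 0 during the 'night' (darkness), 1500 time steps each.
for i=0:4 % one night per simulated day; i=5 would index past T and grow isObserved
isObserved((i*3000)+1:(i*3000)+1500)=0; % the first 1500 trials of each block are dark
end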
% Run the kalman filter once more!
[xfilt, Vfilt, VVfilt, loglik, xpred] = kalman_filter(y, A, C, Q, R, initx, initV,'isObserved',isObserved);
% Now subtract the observations from the summed predictions, accounting for
% the perturbations. Let this be variable r.
r=sum(xpred)-y;
r=r-0.5;
%% produce figure 2c - fitting exponentials to each day.
figure(4)
for i=0:4
subplot(2,1,1)
hold on
plot(i*1500+1500+(1:1500),1+r(i*3000+1500+(1:1500)),'ko');
[x4(i+1,:)]=fitExponential([1:1500]/1000,1+r(i*3000+1500+(1:1500)));
plot(i*1500+1500+(1:1500),x4(i+1,1)+x4(i+1,2)*exp(x4(i+1,3)*(1:1500)/1000),'r');
end
xlabel('time (saccades during the experiment)')
ylabel('relative size of saccade');
title('adaptation interleaved with nights in the dark');
subplot(2,1,2)
hold on
imagesc(xpred,[-1 1]*max(abs(xpred(:))));
xlabel('time (saccades, including those simulated in the dark)')
ylabel('log timescale (2-33000)');
title('inferred disturbances for all timescales');
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%