-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathselectTaskELLA.m
More file actions
86 lines (86 loc) · 2.93 KB
/
selectTaskELLA.m
File metadata and controls
86 lines (86 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
%%
% Select the next task to learn using the specified selection criterion
%
% inputs -
% model: the ELLA model
% Xs: a cell array containing some seed data to use for selecting the next task
% to learn
% Ys: a cell array containing the labels of the seed data
% selectionCriterion: 1 (random)
% 2 (InfoMax)
% 3 (diversity)
% 4 (diversity++)
% Xtarget (optional): the data for the target task
% Ytarget (optional): the labels for the target task
%
% outputs -
% taskid: the next task to learn (as selected by the specified selection
% criterion)
%
% Copyright (C) Paul Ruvolo and Eric Eaton 2013
%
% This file is part of ELLA.
%
% ELLA is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% ELLA is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with ELLA. If not, see <http://www.gnu.org/licenses/>.
function taskid = selectTaskELLA(model,Xs,Ys,selectionCriterion,Xtarget,Ytarget)
    % SELECTTASKELLA  Pick the index of the next candidate task to learn.
    %
    %   taskid = selectTaskELLA(model,Xs,Ys,selectionCriterion)
    %   taskid = selectTaskELLA(model,Xs,Ys,selectionCriterion,Xtarget,Ytarget)
    %
    %   Each candidate task t (data Xs{t}, labels Ys{t}) is scored and an index
    %   into Ys is returned:
    %     1 - uniformly random index, no scoring
    %     2 - InfoMax: logdet of the updated, regularised model.A (or of its
    %         projection through the target task's code when Xtarget/Ytarget
    %         are both supplied); highest score wins
    %     3 - diversity: quadratic form of the residual between the single-task
    %         weights and their reconstruction; highest score wins
    %     4 - diversity++: same score as 3, but the index is sampled with
    %         probability proportional to the squared score
    %
    %   Errors if selectionCriterion is not in 1:5, or if target data is
    %   supplied with a criterion other than 2.
    %
    %   NOTE(review): 5 passes validation but no branch below handles it, so
    %   every score stays 0 and max() returns task 1 - confirm whether a fifth
    %   criterion is intended.

    % Targeted mode is enabled only when BOTH optional arguments are present.
    doTargeted = nargin == 6;

    if ~ismember(selectionCriterion,1:5)
        error('Invalid Task Selection Criterion Specified');
    end
    if doTargeted && selectionCriterion ~= 2
        error('Targeted selection only works with InfoMax');
    end

    numTasks = length(Ys);

    if selectionCriterion == 1 % random selection
        % randi is the base-MATLAB replacement for the removed
        % Communications Toolbox function randint(1,1,[1 numTasks]).
        taskid = randi(numTasks);
        return;
    end

    useInfoMax = selectionCriterion == 2;
    useDiversity = selectionCriterion == 3 || selectionCriterion == 4;

    taskGoodness = zeros(numTasks,1);

    if doTargeted
        % Only the code vector of the target task is used below.
        [sTarget thetaTarget DTarget taskSpecificTarget] = ...
            encodeTaskELLA(model,Xtarget,Ytarget);
    end

    for t = 1 : numTasks
        [sCurr wCurr DCurr taskSpecificCurr] = ...
            encodeTaskELLA(model,Xs{t},Ys{t});

        if useInfoMax
            % model.A as it would look after adding task t, averaged over
            % T+1 tasks and regularised with lambda on the diagonal.
            ACurr = (model.A + kron(sCurr*sCurr',DCurr))./(model.T+1) + model.lambda*eye(model.d*model.k);
            if ~doTargeted
                % compute using the d-optimality criterion
                taskGoodness(t) = logdet(ACurr);
            else
                % compute using the d-optimality for the parameter vector for the task
                Psi = kron(sTarget, eye(model.d));
                taskGoodness(t) = logdet(Psi'*ACurr*Psi);
            end
        end

        if useDiversity
            % NOTE(review): sum(sCurr) can cancel to zero for a non-zero code
            % with mixed signs; any(sCurr) may be what is intended - confirm.
            if sum(sCurr) ~= 0
                % encode using the difference in loss between single task model and encoded model (Diversity Heuristic)
                residual = wCurr - model.L*sCurr - taskSpecificCurr;
                taskGoodness(t) = residual'*DCurr*residual;
            end
        end
    end

    if selectionCriterion ~= 4
        % always pick the best task (first index on ties)
        [~, taskid] = max(taskGoodness);
    else
        % select task probabilistically
        if sum(taskGoodness) == 0
            % No task has a non-zero score, so there is nothing to sample from.
            taskid = 1;
        else
            probs = taskGoodness.^2./sum(taskGoodness.^2);
            taskid = find(cumsum(probs) > rand, 1, 'first');
            if isempty(taskid)
                % Rounding can leave the final cumulative value just below
                % the draw; fall back to the last task with positive
                % probability instead of returning [].
                taskid = find(probs > 0, 1, 'last');
            end
        end
    end
end