Skip to content

Commit 53f1382

Browse files
committed
fix weights
1 parent 86d2bba commit 53f1382

File tree

2 files changed

+255
-9
lines changed

2 files changed

+255
-9
lines changed
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
function [major_classes, full_distributions, seg_classes, class_w] = distr_strategies(segmentation_configs, classification_configs, varargin)
2+
%DISTR_STRATEGIES computes the prefered strategy for a small time window
3+
%for each trajectory.
4+
5+
%% OPTIONS (INPUT) %%
6+
%-sigma: controls how many segments influence the choice of segment
7+
% classes over the swimming paths.
8+
%-discard_undefined: keep/discard undefined
9+
%-w: method of computed the weights
10+
%-norm_method: normalizes the weights (off/OFF/0) to skip
11+
%-tiny_num: used for computed the full_distributions (data smoothing)
12+
sigma = 4;
13+
discard_undefined = 0;
14+
%w = 'defined';
15+
w = 'computed';
16+
norm_method = 'off';
17+
hard_bounds = 'on';
18+
%tiny_num = 1e-6;
19+
tiny_num = realmin;
20+
min_seg = 1;
21+
22+
%% INITIALIZE USER INPUT %%
23+
for i = 1:length(varargin)
24+
if isequal(varargin{i},'sigma')
25+
sigma = varargin{i+1};
26+
if sigma <= 0
27+
sigma = 1;
28+
end
29+
elseif isequal(varargin{i},'discard_undefined')
30+
if isequal(varargin{i+1},1) || isequal(varargin{i+1},'on') || isequal(varargin{i+1},'ON')
31+
discard_undefined = 1;
32+
end
33+
elseif isequal(varargin{i},'smoothing')
34+
min_seg = varargin{i+1};
35+
if min_seg <= 0
36+
min_seg = 1;
37+
end
38+
elseif isequal(varargin{i},'weights')
39+
w = varargin{i+1};
40+
elseif isequal(varargin{i},'norm_method')
41+
norm_method = varargin{i+1};
42+
elseif isequal(varargin{i},'hard_bounds')
43+
hard_bounds = varargin{i+1};
44+
end
45+
end
46+
47+
%% INITIALIZATION %%
48+
class_map = classification_configs.CLASSIFICATION.class_map;
49+
%segment length
50+
lengths = segmentation_configs.FEATURES_VALUES_SEGMENTS(:,10);
51+
%class slots
52+
nbins = max(segmentation_configs.PARTITION);
53+
%segment class slots
54+
seg_classes = zeros(1,length(class_map));
55+
%all classes
56+
classes = zeros(1,length(classification_configs.CLASSIFICATION_TAGS));
57+
for i = 1:length(classes)
58+
classes(i) = classification_configs.CLASSIFICATION_TAGS{i}{1,3};
59+
end
60+
%class weights
61+
class_w = zeros(1,length(classes));
62+
switch w
63+
case 'ones'
64+
class_w = ones(1,length(classes));
65+
case 'defined'
66+
for i = 1:length(classes)
67+
class_w(i) = classification_configs.CLASSIFICATION_TAGS{i}{1,4};
68+
end
69+
case 'computed'
70+
class_w = computed_weights(segmentation_configs, classification_configs);
71+
end
72+
%weights normalization
73+
if ~isequal(norm_method,'off') && ~isequal(norm_method,'OFF') && ~isequal(norm_method,0)
74+
class_w = normalizations(class_w,norm_method);
75+
end
76+
%hard bounds
77+
if ~isequal(hard_bounds,'off') && ~isequal(hard_bounds,'OFF') && ~isequal(hard_bounds,0);
78+
avg_w = max(class_w) - min(class_w);
79+
avg_w = avg_w / 2;
80+
avg_w = avg_w + min(class_w);
81+
for i = 1:length(class_w)
82+
if class_w(i) < avg_w
83+
class_w(i) = class_w;
84+
else
85+
class_w(i) = max(class_w);
86+
end
87+
end
88+
end
89+
90+
%strategies distribution
91+
class_distr_traj = [];
92+
%array that shows numerically distribution values
93+
full_distributions = [];
94+
%final strategy distribution
95+
major_classes = [];
96+
%
97+
undef = [];
98+
%for matching segments to trajectory
99+
id = [-1, -1, -1];
100+
%the ith path segment, Si
101+
iseg = 0;
102+
103+
%% PROCESSING %%
104+
for i = 1:classification_configs.CLASSIFICATION.segments.count
105+
segment = classification_configs.CLASSIFICATION.segments.items(i);
106+
if ~isequal(id, segment.data_identification)
107+
%take the segment id
108+
id = segment.data_identification;
109+
%distribute the classes
110+
if ~isempty(class_distr_traj)
111+
% %full distributions
112+
% tmp = class_distr_traj;
113+
% tmp(tmp(:) == -1) = 0;
114+
% if isempty(undef) %we do not have undefined
115+
% nrm = repmat(sum(tmp,2) + tiny_num, undef', 1, length(classes));
116+
% else
117+
% nrm = repmat(sum(tmp,2) + tiny_num, undef', 1, length(classes)-1);
118+
% end
119+
% nrm(class_distr_traj == -1) = 1;
120+
% class_distr_traj = class_distr_traj ./ nrm;
121+
% full_distributions = [full_distr, class_distr_traj];
122+
123+
%take only the most frequent class for each bin and traj
124+
traj_distr = zeros(1,nbins);
125+
for j = 1:nbins
126+
[val,pos] = max(class_distr_traj(j,:));
127+
if val > 0
128+
if undefined(j) > val && discard_undefined
129+
traj_distr(j) = 0;
130+
else
131+
traj_distr(j) = pos;
132+
end
133+
else
134+
if j > iseg
135+
traj_distr(j) = -1;
136+
else
137+
traj_distr(j) = 0;
138+
end
139+
end
140+
end
141+
major_classes = [major_classes; traj_distr];
142+
end
143+
144+
%exclude the undefined
145+
undef = find(classes == 0);
146+
if isempty(undef)
147+
class_distr_traj = ones(nbins,length(classes))*-1;
148+
else
149+
class_distr_traj = ones(nbins,length(classes)-1)*-1;
150+
end
151+
undefined = zeros(1,nbins);
152+
iseg = 0;
153+
end
154+
iseg = iseg + 1;
155+
156+
wi = iseg; %current segment
157+
wf = iseg; %overlapped segments
158+
coverage = segment.offset + lengths(i);
159+
for j = i+1 : classification_configs.CLASSIFICATION.segments.count
160+
segment_ = classification_configs.CLASSIFICATION.segments.items(j);
161+
if ~isequal(id,segment_.data_identification) || segment_.offset > coverage
162+
wf = iseg - 1 + j - i - 1;
163+
break;
164+
end
165+
end
166+
167+
% mid-point
168+
m = (wi + wf) / 2;
169+
for j = wi:wf
170+
if class_map(i) > 0
171+
col = class_map(i);
172+
%equation 2, supplementary material page 7
173+
val = class_w(col)*exp(-(j-m)^2 / (2*sigma^2));
174+
if class_distr_traj(j,col) == -1
175+
class_distr_traj(j,col) = val;
176+
else
177+
class_distr_traj(j,col) = class_distr_traj(j,col) + val;
178+
end
179+
elseif discard_undefined
180+
undefined(j) = undefined(j) + 1;
181+
end
182+
end
183+
end
184+
185+
%% GENERATE RESULTS %%
186+
if ~isempty(class_distr_traj)
187+
%FULL DISTRIBUTIONS
188+
tmp = class_distr_traj;
189+
tmp(tmp(:) == -1) = 0;
190+
if isempty(undef) %we do not have undefined
191+
nrm = repmat(sum(tmp,2) + tiny_num, 1, length(classes));
192+
else
193+
nrm = repmat(sum(tmp,2) + tiny_num, 1, length(classes)-1);
194+
end
195+
nrm(class_distr_traj == -1) = 1;
196+
full_distributions = class_distr_traj ./ nrm;
197+
198+
%MAJOR CLASSES
199+
for j = 1:nbins
200+
[val,pos] = max(class_distr_traj(j,:));
201+
if val > 0
202+
traj_distr(j) = pos;
203+
else
204+
if j > iseg
205+
traj_distr(j) = -1;
206+
else
207+
traj_distr(j) = 0;
208+
end
209+
end
210+
end
211+
major_classes = [major_classes; traj_distr];
212+
213+
%Extra: remove spurious segments (or "smooth" the data)
214+
if min_seg > 1
215+
for i = 1:size(major_classes, 1)
216+
j = 1;
217+
lastc = -1;
218+
lasti = 0;
219+
while(j <= size(major_classes, 2) && major_classes(i, j) ~= -1)
220+
if lastc == -1
221+
lastc = major_classes(i, j);
222+
lasti = j;
223+
elseif major_classes(i, j) ~= lastc
224+
if (j - lasti) < min_seg && lastc ~= 0
225+
if lasti > 1
226+
% find middle point
227+
m = floor( (j + lasti) / 2);
228+
major_classes(i, lasti:m) = major_classes(i, lasti - 1);
229+
major_classes(i, m + 1:j) = major_classes(i, j);
230+
end
231+
end
232+
lastc = major_classes(i, j);
233+
lasti = j;
234+
end
235+
j = j + 1;
236+
end
237+
end
238+
end
239+
240+
%FINAL SEGMENTS
241+
%re-map distribution to the flat list of segments
242+
index = 1;
243+
partitions = segmentation_configs.PARTITION;
244+
partitions = partitions(partitions > 0);
245+
for i = 1:length(partitions)
246+
seg_classes(index : index + partitions(i) - 1) = major_classes(i, 1:partitions(i));
247+
index = index + partitions(i);
248+
end
249+
end
250+
end
251+

strategy_distribution/distr_strategies.m

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,20 +75,15 @@
7575
end
7676
%hard bounds
7777
if ~isequal(hard_bounds,'off') && ~isequal(hard_bounds,'OFF') && ~isequal(hard_bounds,0);
78-
avg_w = max(class_w)/2;
79-
avg_w = avg_w-1;
80-
avg_w = avg_w/2;
78+
avg_w = max(class_w) - min(class_w);
79+
avg_w = avg_w / 2;
80+
avg_w = avg_w + min(class_w);
8181
for i = 1:length(class_w)
8282
if class_w(i) < avg_w
83-
class_w(i) = 1;
83+
class_w(i) = min(class_w);
8484
else
8585
class_w(i) = max(class_w);
8686
end
87-
% if class_w(i) < 5
88-
% class_w(i) = 1;
89-
% else
90-
% class_w(i) = 10;
91-
% end
9287
end
9388
end
9489

0 commit comments

Comments
 (0)