-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathevaluateclass.m
More file actions
78 lines (76 loc) · 2.71 KB
/
evaluateclass.m
File metadata and controls
78 lines (76 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
function [ P, stats ] = evaluateclass( G, Ghat, varargin )
%EVALUATECLASS Score binary label predictions column by column.
%   P = EVALUATECLASS(G, Ghat) compares the known labels G with the
%   predicted labels Ghat one column at a time (each run of size(G, 1)
%   consecutive elements) and returns the Matthews correlation coefficient
%   (MCC) of every column in P.
%
%   P = EVALUATECLASS(G, Ghat, CRITERION) returns CRITERION instead of the
%   MCC. CRITERION is one of 'Accuracy', 'Prevalence', 'Precision', 'NPV',
%   'Recall', 'Specificity', 'F1score', 'MCC', 'Informedness' or
%   'Markedness'; matching follows validatestring rules (case-insensitive,
%   unambiguous prefixes accepted).
%
%   [P, STATS] = EVALUATECLASS(...) additionally returns a 1-by-K struct
%   array STATS (K = number of columns) whose fields are all of the
%   statistics listed above.
%
%   G and Ghat must be non-empty logical or numeric arrays and are expected
%   to have the same size. Exactly two distinct labels are expected: labels
%   are ordered ascending, the first is treated as the negative class and
%   the second as the positive class (false / true for logical input).
%   P has the size of G with its first dimension removed, e.g. K-by-1 for
%   an N-by-K input. Statistics whose denominator is zero come out as NaN.
%
%   Requires confusionmat (Statistics and Machine Learning Toolbox) and the
%   project helper invalidateMismatch.
    p = inputParser;
    validGroup = @(x) validateattributes(x, {'logical', 'numeric'}, ...
        {'nonempty'});
    allStats = {'Accuracy', 'Prevalence', 'Precision', 'NPV', 'Recall', ...
        'Specificity', 'F1score', 'MCC', 'Informedness', 'Markedness'}';
    % validatestring throws for an unknown or ambiguous name, so any() only
    % turns a successful match into logical true.
    validCriterion = @(x) any(validatestring(x, allStats));
    addRequired(p, 'G', validGroup);
    addRequired(p, 'Ghat', validGroup);
    addOptional(p, 'Criterion', 'mcc', validCriterion);
    % varargin must be forwarded; otherwise the optional Criterion argument
    % is silently ignored and the default 'mcc' is always used.
    parse(p, G, Ghat, varargin{:});
    % Project helper defined elsewhere; presumably errors when G and Ghat
    % disagree in row count / dimensionality -- TODO confirm.
    invalidateMismatch(G, Ghat, 'G', 'Ghat', 'row');
    invalidateMismatch(G, Ghat, 'G', 'Ghat', 'dimension');

    % inputParser keeps the caller's raw text (e.g. 'f1' or 'mcc'); expand it
    % to the canonical entry of allStats so the lookups below always match.
    criterion = validatestring(p.Results.Criterion, allStats);
    criterionIndex = strcmp(allStats, criterion);

    % Each run of nObs consecutive elements (a column for a matrix) is scored
    % independently; P drops the first dimension of G. The trailing 1 keeps P
    % a column vector when G is N-by-K.
    nObs = size(G, 1);
    nCols = numel(G) / nObs;
    dim = size(G);
    P = zeros([dim(2:end), 1]);

    % Fix the label order once for all columns so every confusion matrix has
    % the same [negative, positive] layout, even for a column in which only
    % one class occurs.
    if islogical(G) && islogical(Ghat)
        labels = [false; true];
    else
        labels = unique([G(:); Ghat(:)]);
        % NaN is a missing value for confusionmat, not a label.
        labels = labels(~isnan(labels));
    end

    % Iterate backwards so the first assignment to stats(i) allocates the
    % whole struct array at once.
    for i = nCols:-1:1
        ind = (1:nObs) + (i - 1) * nObs;
        conf = confusionmat(G(ind), Ghat(ind), 'Order', labels);
        if nargout >= 2
            S = computeStats(conf, allStats);
            stats(i) = cell2struct(S, allStats);
            P(i) = S{criterionIndex};
        else
            % Only the requested statistic is needed.
            S = computeStats(conf, {criterion});
            P(i) = S{1};
        end
    end
end
function [ S ] = computeStats(conf, statList)
%COMPUTESTATS Binary classification statistics from a 2-by-2 confusion matrix.
%   S = COMPUTESTATS(CONF, STATLIST) returns a cell array the size of
%   STATLIST whose i-th entry is the statistic named STATLIST{i}
%   (case-insensitive): 'accuracy', 'prevalence', 'precision', 'npv',
%   'recall', 'specificity', 'f1score', 'mcc', 'informedness' or
%   'markedness'.
%
%   CONF uses the confusionmat layout: rows are known labels, columns are
%   predicted labels, and index 1 / 2 is the negative / positive class:
%
%       CONF = [TN FP
%               FN TP]
%
%   A zero denominator yields NaN (every such case is 0/0). An unknown
%   statistic name, or a CONF that is not 2-by-2, raises an error.
    % Indexing conf(2,2) on a larger (multi-class) matrix would silently
    % produce meaningless numbers, and on a 1x1 matrix a cryptic index error.
    if ~isequal(size(conf), [2 2])
        error('computeStats:notBinary', ...
            'Expected a 2-by-2 confusion matrix but got %d-by-%d.', ...
            size(conf, 1), size(conf, 2));
    end
    S = cell(size(statList));
    TP = conf(2, 2);
    TN = conf(1, 1);
    FP = conf(1, 2);
    FN = conf(2, 1);
    nObs = sum(conf(:));
    for i = 1:numel(statList)
        switch lower(statList{i})
            case 'accuracy'
                S{i} = (TP + TN) / nObs;
            case 'prevalence'
                S{i} = (TP + FN) / nObs;
            case 'precision'
                S{i} = TP / (TP + FP);
            case 'npv'
                S{i} = TN / (TN + FN);
            case 'recall'
                S{i} = TP / (TP + FN);
            case 'specificity'
                S{i} = TN / (TN + FP);
            case 'f1score'
                S{i} = (2 * TP) / (2 * TP + FP + FN);
            case 'mcc'
                S{i} = (TP*TN - FP*FN) / sqrt((TP + FP) * (TP + FN) * ...
                    (TN + FP) * (TN + FN));
            case 'informedness'
                % recall + specificity - 1
                S{i} = (TP / (TP + FN)) + (TN / (TN + FP)) - 1;
            case 'markedness'
                % precision + NPV - 1
                S{i} = (TP / (TP + FP)) + (TN / (TN + FN)) - 1;
            otherwise
                error('computeStats:invalidStat', ...
                    'Invalid statistic requested: %s', statList{i});
        end
    end
end
function value = getfieldi(S, field)
%GETFIELDI Case-insensitive field lookup on a struct.
%   VALUE = GETFIELDI(S, FIELD) returns the field of S whose name equals
%   FIELD ignoring case, or [] when S has no such field.
%   MATLAB field names are case-sensitive, so S may contain several fields
%   that differ only in case; an exact (case-sensitive) match is then
%   preferred, otherwise the first match in fieldnames(S) order is used.
    if isfield(S, field)
        value = S.(field);
        return
    end
    names = fieldnames(S);
    % find(..., 1) keeps a single name; indexing names with a multi-true
    % mask would expand to a comma-separated list and break S.(...).
    match = find(strcmpi(field, names), 1);
    if isempty(match)
        value = [];
    else
        value = S.(names{match});
    end
end