Skip to content

Commit af317a7

Browse files
committed
confusion matrix
1 parent 08382b4 commit af317a7

File tree

3 files changed

+91
-27
lines changed

3 files changed

+91
-27
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
ppath = 'I:\Documents\MWMGEN\tiago_original\Mclassification\class_1301_10388_250_07_10_10_mr0';
2+
output_dir = '\\staffstore\avgoustinos\.redirect\Desktop\res';
3+
4+
folds = 10;
5+
files = dir(fullfile(ppath,'*.mat'));
6+
for i = 1:length(files)
7+
output_dir = fullfile(output_dir,files(i).name);
8+
mkdir(output_dir);
9+
load(fullfile(ppath,files(i).name));
10+
results_confusion_matrix(classification_configs,folds,output_dir);
11+
end
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
function results_confusion_matrix(segmentation_configs,classification_configs,folds,output_dir)
2+
% Computes the confusion matrix for the classification of segments.
3+
% Values are the total number of missclassifications for a 10-fold
4+
% cross-validation of the clustering algorithm
5+
6+
fn = fullfile(output_dir, 'confusion_matrix.mat');
7+
fn2 = fullfile(output_dir, 'Confusion_matrix.txt');
8+
9+
% get classifier object
10+
classif = classification_configs.SEMISUPERVISED_CLUSTERING;
11+
defaults_clusters = classification_configs.DEFAULT_NUMBER_OF_CLUSTERS;
12+
13+
% check if the folds number is too high
14+
a = length(classification_configs.CLASSIFICATION.class_map);
15+
if iscell(folds)
16+
if isempty(str2num(folds{1,1}))
17+
disp('Insert a number to specify the N-fold cross-validation.');
18+
return
19+
end
20+
elseif ~isnumeric(folds)
21+
if isempty(str2num(folds))
22+
disp('Insert a number to specify the N-fold cross-validation.');
23+
return
24+
end
25+
if a - folds <= 0 || a - folds < a/10
26+
disp('Input number too high, insert a lower number.');
27+
return
28+
end
29+
else
30+
if a - folds <= 0 || a - folds < a/10
31+
disp('Input number too high, insert a lower number.');
32+
return
33+
end
34+
end
35+
36+
% perform a N-fold cross-validation
37+
res = classif.cluster_cross_validation(defaults_clusters, 'Folds', folds);
38+
39+
% take the "total confusion matrix"
40+
cm = res.results(1).confusion_matrix;
41+
for i = 2:folds
42+
cm = cm + res.results(i).confusion_matrix;
43+
end
44+
tags = res.results(1).classes;
45+
46+
% NaN values
47+
for i = 1:size(cm,1)
48+
for j = 1:size(cm,2)
49+
if isnan(cm)
50+
cm(i,j) = 0;
51+
end
52+
end
53+
end
54+
55+
% save data
56+
save(fn, 'tags', 'cm');
57+
disp('Tags:');
58+
for i = 1:length(tags)
59+
fprintf('%s\n', tags{1,i}{1,2});
60+
end
61+
fprintf('\nConfusion matrix:\n');
62+
cm
63+
64+
% save to file
65+
fileID = fopen(fn2,'wt');
66+
for i = 1:length(tags)
67+
fprintf(fileID,'%d. %s\n', i, tags{1,i}{1,2});
68+
end
69+
fprintf(fileID, '\n');
70+
for i=1:size(cm,1)
71+
fprintf(fileID, '%d ', cm(i,:));
72+
fprintf(fileID, '\n');
73+
end
74+
fclose(fileID);
75+
76+
end

unused/results_confusion_matrix.m

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,15 @@
1-
function results_confusion_matrix(segmentation_configs,classification_configs,folds)
1+
function results_confusion_matrix(classification_configs,folds,output_dir)
22
% Computes the confusion matrix for the classification of segments.
3-
% Values are the total number of missclassifications for a 10-fold
3+
% Values are the total number of missclassifications for a N-fold
44
% cross-validation of the clustering algorithm
55

6-
fn = fullfile(strcat(segmentation_configs.OUTPUT_DIR,'/'), 'confusion_matrix.mat');
7-
fn2 = fullfile(strcat(segmentation_configs.OUTPUT_DIR,'/'), 'Confusion_matrix.txt');
6+
fn = fullfile(output_dir, 'confusion_matrix.mat');
7+
fn2 = fullfile(output_dir, 'Confusion_matrix.txt');
88

99
% get classifier object
1010
classif = classification_configs.SEMISUPERVISED_CLUSTERING;
1111
defaults_clusters = classification_configs.DEFAULT_NUMBER_OF_CLUSTERS;
1212

13-
% check if the folds number is too high
14-
a = length(classification_configs.CLASSIFICATION.class_map);
15-
if iscell(folds)
16-
if isempty(str2num(folds{1,1}))
17-
disp('Insert a number to specify the N-fold cross-validation.');
18-
return
19-
end
20-
elseif ~isnumeric(folds)
21-
if isempty(str2num(folds))
22-
disp('Insert a number to specify the N-fold cross-validation.');
23-
return
24-
end
25-
if a - folds <= 0 || a - folds < a/10
26-
disp('Input number too high, insert a lower number.');
27-
return
28-
end
29-
else
30-
if a - folds <= 0 || a - folds < a/10
31-
disp('Input number too high, insert a lower number.');
32-
return
33-
end
34-
end
35-
3613
% perform a N-fold cross-validation
3714
res = classif.cluster_cross_validation(defaults_clusters, 'Folds', folds);
3815

0 commit comments

Comments
 (0)