1+ function results_confusion_matrix(segmentation_configs,classification_configs,folds,output_dir)
2+ % Computes the confusion matrix for the classification of segments.
3+ % Values are the total number of missclassifications for a 10-fold
4+ % cross-validation of the clustering algorithm
5+
6+ fn = fullfile(output_dir, 'confusion_matrix.mat');
7+ fn2 = fullfile(output_dir, 'Confusion_matrix.txt');
8+
9+ % get classifier object
10+ classif = classification_configs.SEMISUPERVISED_CLUSTERING;
11+ defaults_clusters = classification_configs.DEFAULT_NUMBER_OF_CLUSTERS;
12+
13+ % check if the folds number is too high
14+ a = length(classification_configs.CLASSIFICATION.class_map);
15+ if iscell(folds)
16+ if isempty(str2num(folds{1,1}))
17+ disp('Insert a number to specify the N-fold cross-validation.');
18+ return
19+ end
20+ elseif ~isnumeric(folds)
21+ if isempty(str2num(folds))
22+ disp('Insert a number to specify the N-fold cross-validation.');
23+ return
24+ end
25+ if a - folds <= 0 || a - folds < a/10
26+ disp('Input number too high, insert a lower number.');
27+ return
28+ end
29+ else
30+ if a - folds <= 0 || a - folds < a/10
31+ disp('Input number too high, insert a lower number.');
32+ return
33+ end
34+ end
35+
36+ % perform a N-fold cross-validation
37+ res = classif.cluster_cross_validation(defaults_clusters, 'Folds', folds);
38+
39+ % take the "total confusion matrix"
40+ cm = res.results(1).confusion_matrix;
41+ for i = 2:folds
42+ cm = cm + res.results(i).confusion_matrix;
43+ end
44+ tags = res.results(1).classes;
45+
46+ % NaN values
47+ for i = 1:size(cm,1)
48+ for j = 1:size(cm,2)
49+ if isnan(cm)
50+ cm(i,j) = 0;
51+ end
52+ end
53+ end
54+
55+ % save data
56+ save(fn, 'tags', 'cm');
57+ disp('Tags:');
58+ for i = 1:length(tags)
59+ fprintf('%s\n', tags{1,i}{1,2});
60+ end
61+ fprintf('\nConfusion matrix:\n');
62+ cm
63+
64+ % save to file
65+ fileID = fopen(fn2,'wt');
66+ for i = 1:length(tags)
67+ fprintf(fileID,'%d. %s\n', i, tags{1,i}{1,2});
68+ end
69+ fprintf(fileID, '\n');
70+ for i=1:size(cm,1)
71+ fprintf(fileID, '%d ', cm(i,:));
72+ fprintf(fileID, '\n');
73+ end
74+ fclose(fileID);
75+
76+ end
0 commit comments