1+ #include < FAST/Algorithms/ImageCaster/ImageCaster.hpp>
2+ #include " OtsuThresholding.hpp"
3+ #include " BinaryThresholding.hpp"
4+
5+ namespace fast {
6+
7+ std::array<int , 256 > calculateGlobalHistogram (ImageAccess::pointer& access, const int width, const int height) {
8+ std::array<int , 256 > hist{};
9+ for (int i = 0 ; i < width*height; ++i) {
10+ auto pixelValue = access->getScalarFast <uchar>(i);
11+ hist[pixelValue]++;
12+ }
13+ return hist;
14+ }
15+
16+
17+ std::array<float , 256 > calculatePixelProbabilities (std::array<int , 256 > histogram, const int pixels) {
18+ std::array<float , 256 > probs;
19+ for (int i = 0 ; i < 256 ; ++i)
20+ probs[i] = (float )histogram[i] / (float )pixels;
21+
22+ return probs;
23+ }
24+
25+ std::array<float , 256 > calculateCumulativePixelProbabilities (std::array<int , 256 > histogram, const int pixels) {
26+ std::array<float , 256 > cumulativeProbs;
27+ cumulativeProbs[0 ] = (float )histogram[0 ] / (float )pixels;
28+ for (int i = 1 ; i < 256 ; ++i)
29+ cumulativeProbs[i] = cumulativeProbs[i-1 ] + (float )histogram[i] / (float )pixels;
30+
31+ return cumulativeProbs;
32+ }
33+
34+
35+ OtsuThresholding::OtsuThresholding (int numberOfClasses) {
36+ createInputPort (0 );
37+ createOutputPort (0 );
38+ if (numberOfClasses > 4 || numberOfClasses < 2 )
39+ throw Exception (" Otsu thresholding implementation only supports 2, 3, or 4 classes" );
40+ m_thresholdCount = numberOfClasses - 1 ;
41+ }
42+
43+ void OtsuThresholding::execute () {
44+ auto input = getInputData<Image>();
45+
46+ if (input->getNrOfChannels () > 1 ) {
47+ reportWarning () << " Otsu thresholding implementation only supports 1 channel images. " <<
48+ " Since your input image has more than 1 channels, only the first channel (red) is used." << reportEnd ();
49+ }
50+ if (input->getDimensions () == 3 ) {
51+ throw Exception (" Otsu thresholding is currently only implemented for 2D images" );
52+ }
53+
54+ if (input->getDataType () != TYPE_UINT8) {
55+ input = ImageCaster::create (TYPE_UINT8, 255 , true )
56+ ->connect (input)
57+ ->runAndGetOutputData <Image>();
58+ // TODO handle issue where bins 0 and 255 become less probable due to rounding..
59+ }
60+
61+ auto access = input->getImageAccess (ACCESS_READ);
62+ // TODO Move histogram operation somewhere else?
63+ auto histogram = calculateGlobalHistogram (access, input->getWidth (), input->getHeight ());
64+ access->release ();
65+ auto cumulativeProbs = calculateCumulativePixelProbabilities (histogram, input->getNrOfVoxels ());
66+ auto probs = calculatePixelProbabilities (histogram, input->getNrOfVoxels ());
67+
68+ std::array<float , 256 > aux;
69+ aux[0 ] = 0 ;
70+ for (int i = 1 ; i < 256 ; ++i) {
71+ aux[i] = aux[i-1 ] + i*probs[i];
72+ }
73+
74+ if (m_thresholdCount == 1 ) {
75+ float bestInterClassVariance = 0 ;
76+ int bestThreshold;
77+ // Iterate over all possible thresholds
78+ for (int T = 1 ; T < 256 ; ++T) {
79+ // Calculate inter-class variance
80+ float w0 = cumulativeProbs[T-1 ];
81+ float w1 = cumulativeProbs[255 ] - cumulativeProbs[T];
82+ float mean0 = aux[T-1 ]/w0;
83+ float mean1 = (aux[255 ] - aux[T])/w1;
84+ float interClassVariance = w0*w1*(mean0 - mean1)*(mean0 - mean1);
85+ // Select threshold with highest inter-class variance
86+ if (interClassVariance > bestInterClassVariance) {
87+ bestInterClassVariance = interClassVariance;
88+ bestThreshold = T;
89+ }
90+ }
91+ // Segment using threshold
92+ addOutputData (0 , BinaryThresholding::create (bestThreshold)->connect (input)->runAndGetOutputData <Image>());
93+ } else if (m_thresholdCount == 2 ) {
94+ const float meanG = aux[255 ];
95+ float bestInterClassVariance = 0 ;
96+ int bestThreshold1;
97+ int bestThreshold2;
98+ // Iterate over all possible thresholds
99+ for (int T1 = 1 ; T1 < 256 -1 ; ++T1) {
100+ for (int T2 = T1+1 ; T2 < 256 ; ++T2) {
101+ // Calculate inter-class variance
102+ float w0 = cumulativeProbs[T1-1 ];
103+ float w1 = cumulativeProbs[T2-1 ] - cumulativeProbs[T1];
104+ float w2 = cumulativeProbs[255 ] - cumulativeProbs[T2];
105+ float mean0 = aux[T1-1 ]/w0;
106+ float mean1 = (aux[T2-1 ] - aux[T1])/w1;
107+ float mean2 = (aux[255 ] - aux[T2])/w2;
108+ float interClassVariance = w0*(mean0 - meanG)*(mean0 - meanG) + w1*(mean1 - meanG)*(mean1 - meanG) + w2*(mean2 - meanG)*(mean2 - meanG);
109+ // Select threshold with highest inter-class variance
110+ if (interClassVariance > bestInterClassVariance) {
111+ bestInterClassVariance = interClassVariance;
112+ bestThreshold1 = T1;
113+ bestThreshold2 = T2;
114+ }
115+ }
116+ }
117+ // Segment using multiple threshold
118+ auto segmentation = Image::create (input->getWidth (), input->getHeight (), TYPE_UINT8, 1 );
119+ segmentation->setSpacing (input->getSpacing ());
120+ auto segmentationAccess = segmentation->getImageAccess (ACCESS_READ_WRITE);
121+ auto inputAccess = input->getImageAccess (ACCESS_READ);
122+ for (int i = 0 ; i < input->getNrOfVoxels (); ++i) {
123+ auto value = inputAccess->getScalarFast <uchar>(i);
124+ uchar segmentationClass = 0 ;
125+ if (value >= bestThreshold1 && value < bestThreshold2) {
126+ segmentationClass = 1 ;
127+ } else if (value >= bestThreshold2) {
128+ segmentationClass = 2 ;
129+ }
130+ segmentationAccess->setScalarFast (i, segmentationClass);
131+ }
132+ addOutputData (0 , segmentation);
133+ } else if (m_thresholdCount == 3 ) {
134+ const float meanG = aux[255 ];
135+ float bestInterClassVariance = 0 ;
136+ int bestThreshold1;
137+ int bestThreshold2;
138+ int bestThreshold3;
139+ // Iterate over all possible thresholds
140+ for (int T1 = 1 ; T1 < 256 -2 ; ++T1) {
141+ for (int T2 = T1+1 ; T2 < 256 -1 ; ++T2) {
142+ for (int T3 = T2+1 ; T3 < 256 ; ++T3) {
143+ // Calculate inter-class variance
144+ float w0 = cumulativeProbs[T1-1 ];
145+ float w1 = cumulativeProbs[T2-1 ] - cumulativeProbs[T1];
146+ float w2 = cumulativeProbs[T3-1 ] - cumulativeProbs[T2];
147+ float w3 = cumulativeProbs[255 ] - cumulativeProbs[T3];
148+ float mean0 = aux[T1-1 ]/w0;
149+ float mean1 = (aux[T2-1 ] - aux[T1])/w1;
150+ float mean2 = (aux[T3-1 ] - aux[T2])/w2;
151+ float mean3 = (aux[255 ] - aux[T3])/w3;
152+ float interClassVariance = w0*(mean0 - meanG)*(mean0 - meanG) + w1*(mean1 - meanG)*(mean1 - meanG) + w2*(mean2 - meanG)*(mean2 - meanG) + w3*(mean3 - meanG)*(mean3 - meanG);
153+ // Select threshold with highest inter-class variance
154+ if (interClassVariance > bestInterClassVariance) {
155+ bestInterClassVariance = interClassVariance;
156+ bestThreshold1 = T1;
157+ bestThreshold2 = T2;
158+ bestThreshold3 = T3;
159+ }
160+ }
161+ }
162+ }
163+ // Segment using multiple threshold
164+ auto segmentation = Image::create (input->getWidth (), input->getHeight (), TYPE_UINT8, 1 );
165+ auto segmentationAccess = segmentation->getImageAccess (ACCESS_READ_WRITE);
166+ auto inputAccess = input->getImageAccess (ACCESS_READ);
167+ for (int i = 0 ; i < input->getNrOfVoxels (); ++i) {
168+ uchar value = inputAccess->getScalarFast <uchar>(i);
169+ uchar segmentationClass;
170+ if (value < bestThreshold1) {
171+ segmentationClass = 0 ;
172+ } else if (value >= bestThreshold1 && value < bestThreshold2) {
173+ segmentationClass = 1 ;
174+ } else if (value >= bestThreshold2 && value < bestThreshold3) {
175+ segmentationClass = 2 ;
176+ } else {
177+ segmentationClass = 3 ;
178+ }
179+ segmentationAccess->setScalarFast (i, segmentationClass);
180+ }
181+ addOutputData (0 , segmentation);
182+ } else {
183+ // Not supported
184+ throw NotImplementedException ();
185+ }
186+ }
187+ }
0 commit comments