Skip to content

Commit 372937b

Browse files
committed
Added OtsuThresholding 2D implementation
1 parent 4cb1a45 commit 372937b

File tree

4 files changed

+268
-1
lines changed

4 files changed

+268
-1
lines changed

source/FAST/Algorithms/Thresholding/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
fast_add_sources(
22
BinaryThresholding.cpp
33
BinaryThresholding.hpp
4+
OtsuThresholding.cpp
5+
OtsuThresholding.hpp
6+
)
7+
fast_add_process_object(BinaryThresholding BinaryThresholding.hpp)
8+
fast_add_process_object(OtsuThresholding OtsuThresholding.hpp)
9+
10+
fast_add_test_sources(
11+
OtsuThresholdingTests.cpp
412
)
5-
fast_add_process_object(BinaryThresholding BinaryThresholding.hpp)
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#include <FAST/Algorithms/ImageCaster/ImageCaster.hpp>
2+
#include "OtsuThresholding.hpp"
3+
#include "BinaryThresholding.hpp"
4+
5+
namespace fast {
6+
7+
std::array<int, 256> calculateGlobalHistogram(ImageAccess::pointer& access, const int width, const int height) {
8+
std::array<int, 256> hist{};
9+
for(int i = 0; i < width*height; ++i) {
10+
auto pixelValue = access->getScalarFast<uchar>(i);
11+
hist[pixelValue]++;
12+
}
13+
return hist;
14+
}
15+
16+
17+
std::array<float, 256> calculatePixelProbabilities(std::array<int, 256> histogram, const int pixels) {
18+
std::array<float, 256> probs;
19+
for(int i = 0; i < 256; ++i)
20+
probs[i] = (float)histogram[i] / (float)pixels;
21+
22+
return probs;
23+
}
24+
25+
std::array<float, 256> calculateCumulativePixelProbabilities(std::array<int, 256> histogram, const int pixels) {
26+
std::array<float, 256> cumulativeProbs;
27+
cumulativeProbs[0] = (float)histogram[0] / (float)pixels;
28+
for(int i = 1; i < 256; ++i)
29+
cumulativeProbs[i] = cumulativeProbs[i-1] + (float)histogram[i] / (float)pixels;
30+
31+
return cumulativeProbs;
32+
}
33+
34+
35+
OtsuThresholding::OtsuThresholding(int numberOfClasses) {
36+
createInputPort(0);
37+
createOutputPort(0);
38+
if(numberOfClasses > 4 || numberOfClasses < 2)
39+
throw Exception("Otsu thresholding implementation only supports 2, 3, or 4 classes");
40+
m_thresholdCount = numberOfClasses - 1;
41+
}
42+
43+
void OtsuThresholding::execute() {
44+
auto input = getInputData<Image>();
45+
46+
if(input->getNrOfChannels() > 1) {
47+
reportWarning() << "Otsu thresholding implementation only supports 1 channel images. " <<
48+
"Since your input image has more than 1 channels, only the first channel (red) is used." << reportEnd();
49+
}
50+
if(input->getDimensions() == 3) {
51+
throw Exception("Otsu thresholding is currently only implemented for 2D images");
52+
}
53+
54+
if(input->getDataType() != TYPE_UINT8) {
55+
input = ImageCaster::create(TYPE_UINT8, 255, true)
56+
->connect(input)
57+
->runAndGetOutputData<Image>();
58+
// TODO handle issue where bins 0 and 255 become less probable due to rounding..
59+
}
60+
61+
auto access = input->getImageAccess(ACCESS_READ);
62+
// TODO Move histogram operation somewhere else?
63+
auto histogram = calculateGlobalHistogram(access, input->getWidth(), input->getHeight());
64+
access->release();
65+
auto cumulativeProbs = calculateCumulativePixelProbabilities(histogram, input->getNrOfVoxels());
66+
auto probs = calculatePixelProbabilities(histogram, input->getNrOfVoxels());
67+
68+
std::array<float, 256> aux;
69+
aux[0] = 0;
70+
for(int i = 1; i < 256; ++i) {
71+
aux[i] = aux[i-1] + i*probs[i];
72+
}
73+
74+
if(m_thresholdCount == 1) {
75+
float bestInterClassVariance = 0;
76+
int bestThreshold;
77+
// Iterate over all possible thresholds
78+
for(int T = 1; T < 256; ++T) {
79+
// Calculate inter-class variance
80+
float w0 = cumulativeProbs[T-1];
81+
float w1 = cumulativeProbs[255] - cumulativeProbs[T];
82+
float mean0 = aux[T-1]/w0;
83+
float mean1 = (aux[255] - aux[T])/w1;
84+
float interClassVariance = w0*w1*(mean0 - mean1)*(mean0 - mean1);
85+
// Select threshold with highest inter-class variance
86+
if(interClassVariance > bestInterClassVariance) {
87+
bestInterClassVariance = interClassVariance;
88+
bestThreshold = T;
89+
}
90+
}
91+
// Segment using threshold
92+
addOutputData(0, BinaryThresholding::create(bestThreshold)->connect(input)->runAndGetOutputData<Image>());
93+
} else if(m_thresholdCount == 2) {
94+
const float meanG = aux[255];
95+
float bestInterClassVariance = 0;
96+
int bestThreshold1;
97+
int bestThreshold2;
98+
// Iterate over all possible thresholds
99+
for(int T1 = 1; T1 < 256-1; ++T1) {
100+
for(int T2 = T1+1; T2 < 256; ++T2) {
101+
// Calculate inter-class variance
102+
float w0 = cumulativeProbs[T1-1];
103+
float w1 = cumulativeProbs[T2-1] - cumulativeProbs[T1];
104+
float w2 = cumulativeProbs[255] - cumulativeProbs[T2];
105+
float mean0 = aux[T1-1]/w0;
106+
float mean1 = (aux[T2-1] - aux[T1])/w1;
107+
float mean2 = (aux[255] - aux[T2])/w2;
108+
float interClassVariance = w0*(mean0 - meanG)*(mean0 - meanG) + w1*(mean1 - meanG)*(mean1 - meanG) + w2*(mean2 - meanG)*(mean2 - meanG);
109+
// Select threshold with highest inter-class variance
110+
if(interClassVariance > bestInterClassVariance) {
111+
bestInterClassVariance = interClassVariance;
112+
bestThreshold1 = T1;
113+
bestThreshold2 = T2;
114+
}
115+
}
116+
}
117+
// Segment using multiple threshold
118+
auto segmentation = Image::create(input->getWidth(), input->getHeight(), TYPE_UINT8, 1);
119+
segmentation->setSpacing(input->getSpacing());
120+
auto segmentationAccess = segmentation->getImageAccess(ACCESS_READ_WRITE);
121+
auto inputAccess = input->getImageAccess(ACCESS_READ);
122+
for(int i = 0; i < input->getNrOfVoxels(); ++i) {
123+
auto value = inputAccess->getScalarFast<uchar>(i);
124+
uchar segmentationClass = 0;
125+
if(value >= bestThreshold1 && value < bestThreshold2) {
126+
segmentationClass = 1;
127+
} else if(value >= bestThreshold2) {
128+
segmentationClass = 2;
129+
}
130+
segmentationAccess->setScalarFast(i, segmentationClass);
131+
}
132+
addOutputData(0, segmentation);
133+
} else if(m_thresholdCount == 3) {
134+
const float meanG = aux[255];
135+
float bestInterClassVariance = 0;
136+
int bestThreshold1;
137+
int bestThreshold2;
138+
int bestThreshold3;
139+
// Iterate over all possible thresholds
140+
for(int T1 = 1; T1 < 256-2; ++T1) {
141+
for(int T2 = T1+1; T2 < 256-1; ++T2) {
142+
for(int T3 = T2+1; T3 < 256; ++T3) {
143+
// Calculate inter-class variance
144+
float w0 = cumulativeProbs[T1-1];
145+
float w1 = cumulativeProbs[T2-1] - cumulativeProbs[T1];
146+
float w2 = cumulativeProbs[T3-1] - cumulativeProbs[T2];
147+
float w3 = cumulativeProbs[255] - cumulativeProbs[T3];
148+
float mean0 = aux[T1-1]/w0;
149+
float mean1 = (aux[T2-1] - aux[T1])/w1;
150+
float mean2 = (aux[T3-1] - aux[T2])/w2;
151+
float mean3 = (aux[255] - aux[T3])/w3;
152+
float interClassVariance = w0*(mean0 - meanG)*(mean0 - meanG) + w1*(mean1 - meanG)*(mean1 - meanG) + w2*(mean2 - meanG)*(mean2 - meanG) + w3*(mean3 - meanG)*(mean3 - meanG);
153+
// Select threshold with highest inter-class variance
154+
if(interClassVariance > bestInterClassVariance) {
155+
bestInterClassVariance = interClassVariance;
156+
bestThreshold1 = T1;
157+
bestThreshold2 = T2;
158+
bestThreshold3 = T3;
159+
}
160+
}
161+
}
162+
}
163+
// Segment using multiple threshold
164+
auto segmentation = Image::create(input->getWidth(), input->getHeight(), TYPE_UINT8, 1);
165+
auto segmentationAccess = segmentation->getImageAccess(ACCESS_READ_WRITE);
166+
auto inputAccess = input->getImageAccess(ACCESS_READ);
167+
for(int i = 0; i < input->getNrOfVoxels(); ++i) {
168+
uchar value = inputAccess->getScalarFast<uchar>(i);
169+
uchar segmentationClass;
170+
if(value < bestThreshold1) {
171+
segmentationClass = 0;
172+
} else if(value >= bestThreshold1 && value < bestThreshold2) {
173+
segmentationClass = 1;
174+
} else if(value >= bestThreshold2 && value < bestThreshold3) {
175+
segmentationClass = 2;
176+
} else {
177+
segmentationClass = 3;
178+
}
179+
segmentationAccess->setScalarFast(i, segmentationClass);
180+
}
181+
addOutputData(0, segmentation);
182+
} else {
183+
// Not supported
184+
throw NotImplementedException();
185+
}
186+
}
187+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#pragma once
2+
3+
#include <FAST/ProcessObject.hpp>
4+
5+
namespace fast {
6+
7+
/**
8+
* @brief Otsu thresholding segmentation method
9+
*
10+
* Automatically determines the threshold to use for segmentation
11+
* using Otsu's method. Supports multi-class thresholding by changing the
12+
* number of classes argument in the constructor (default is 2 classes = 1 threshold).
13+
* Maximum number of classes is currently 4.
14+
*
15+
* The implementation uses 256 bins when calculating the histogram.
16+
* Input images that are not of UINT8 type are normalized to [0, 255] range and cast to UINT8 before processing.
17+
*
18+
* Inputs:
19+
* - 0: Image (only first channel is used, multi-channel support not implemented).
20+
*
21+
* Outputs:
22+
* - 0: Image segmentation
23+
*
24+
* @todo Add methods for getting thresholds.
25+
* @todo 3D support
26+
* @todo Multi channel support
27+
*
28+
* @ingroup segmentation
29+
*/
30+
class FAST_EXPORT OtsuThresholding : public ProcessObject {
31+
FAST_PROCESS_OBJECT(OtsuThresholding)
32+
public:
33+
/**
34+
* @brief Create instance
35+
* @param numberOfClasses Numbef of classes to use, minimum 2, maximum 4. The number of thresholds is classes - 1.
36+
* @return instance
37+
*/
38+
FAST_CONSTRUCTOR(OtsuThresholding, int, numberOfClasses,= 2)
39+
private:
40+
void execute() override;
41+
int m_thresholdCount;
42+
};
43+
44+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include <FAST/Testing.hpp>
2+
#include <FAST/Importers/ImageFileImporter.hpp>
3+
#include <FAST/Visualization/Shortcuts.hpp>
4+
#include <FAST/Algorithms/ImageCaster/ImageCaster.hpp>
5+
#include "OtsuThresholding.hpp"
6+
7+
using namespace fast;
8+
9+
TEST_CASE("Otsu thresholding 2D uint8 image", "[fast][OtsuThresholding][visual]") {
10+
auto importer = ImageFileImporter::create(Config::getTestDataPath() + "US/US-2D.jpg");
11+
auto segment = OtsuThresholding::create()->connect(importer);
12+
Display2DArgs args;
13+
args.image = importer;
14+
args.segmentation = segment;
15+
args.timeout = 1000;
16+
display2D(args);
17+
}
18+
19+
20+
TEST_CASE("Otsu thresholding 2D float image", "[fast][OtsuThresholding][visual]") {
21+
auto importer = ImageFileImporter::create(Config::getTestDataPath() + "US/US-2D.jpg");
22+
auto caster = ImageCaster::create(TYPE_FLOAT, 2.0f)->connect(importer);
23+
auto segment = OtsuThresholding::create()->connect(caster);
24+
Display2DArgs args;
25+
args.image = importer;
26+
args.segmentation = segment;
27+
args.timeout = 1000;
28+
display2D(args);
29+
}

0 commit comments

Comments
 (0)