Skip to content

Commit 419230d

Browse files
committed
Added 3D support for ImageCaster and Otsu
1 parent 372937b commit 419230d

File tree

7 files changed

+147
-45
lines changed

7 files changed

+147
-45
lines changed

source/FAST/Algorithms/ImageCaster/ImageCaster.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ namespace fast {
66
ImageCaster::ImageCaster() {
77
createInputPort(0, "Image");
88
createOutputPort(0, "Image");
9-
createOpenCLProgram(Config::getKernelSourcePath() + "/Algorithms/ImageCaster/ImageCaster.cl");
9+
createOpenCLProgram(Config::getKernelSourcePath() + "/Algorithms/ImageCaster/ImageCaster2D.cl", "2D");
10+
createOpenCLProgram(Config::getKernelSourcePath() + "/Algorithms/ImageCaster/ImageCaster3D.cl", "3D");
1011
}
1112

1213
ImageCaster::ImageCaster(DataType outputType, float scaleFactor, bool normalizeFirst) : ImageCaster() {
@@ -17,8 +18,6 @@ ImageCaster::ImageCaster(DataType outputType, float scaleFactor, bool normalizeF
1718

1819
void ImageCaster::execute() {
1920
auto input = getInputData<Image>();
20-
if(input->getDimensions() == 3)
21-
throw Exception("Image caster only supports 2D for now");
2221

2322
float minimum = 0.0f;
2423
float maximum = 0.0f;
@@ -34,22 +33,23 @@ void ImageCaster::execute() {
3433

3534
auto queue = device->getCommandQueue();
3635

37-
auto inputAccess = input->getOpenCLImageAccess(ACCESS_READ, device);
38-
auto outputAccess = output->getOpenCLImageAccess(ACCESS_READ_WRITE, device);
39-
cl::Kernel kernel(getOpenCLProgram(device), "cast2D");
40-
kernel.setArg(0, *inputAccess->get2DImage());
41-
kernel.setArg(1, *outputAccess->get2DImage());
36+
Kernel kernel;
37+
if(input->getDimensions() == 2) {
38+
kernel = getKernel("cast2D", "2D");
39+
} else if(getMainOpenCLDevice()->isWritingTo3DTexturesSupported()) {
40+
kernel = getKernel("cast3D", "3D");
41+
} else {
42+
kernel = getKernel("cast3DBuffer", "3D", "-DTYPE=" + getCTypeAsString(output->getDataType()));
43+
kernel.setArg(6, input->getNrOfChannels());
44+
}
45+
kernel.setArg(0, input);
46+
kernel.setArg(1, output);
4247
kernel.setArg(2, m_scaleFactor);
4348
kernel.setArg(3, (char)(m_normalizeFirst ? 1 : 0));
4449
kernel.setArg(4, minimum);
4550
kernel.setArg(5, maximum);
4651

47-
queue.enqueueNDRangeKernel(
48-
kernel,
49-
cl::NullRange,
50-
cl::NDRange(input->getWidth(), input->getHeight()),
51-
cl::NullRange
52-
);
52+
getQueue().add(kernel, input->getSize());
5353

5454
addOutputData(0, output);
5555
}

source/FAST/Algorithms/ImageCaster/ImageCaster.cl renamed to source/FAST/Algorithms/ImageCaster/ImageCaster2D.cl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
22

3-
float4 readImageAsFloat2D(__read_only image2d_t image, sampler_t sampler, int2 position) {
3+
inline float4 readImageAsFloat2D(__read_only image2d_t image, sampler_t sampler, int2 position) {
44
int dataType = get_image_channel_data_type(image);
55
if(dataType == CLK_FLOAT || dataType == CLK_SNORM_INT16 || dataType == CLK_UNORM_INT16) {
66
return read_imagef(image, sampler, position);
@@ -11,7 +11,7 @@ float4 readImageAsFloat2D(__read_only image2d_t image, sampler_t sampler, int2 p
1111
}
1212
}
1313

14-
void writeImageAsFloat2D(__write_only image2d_t image, int2 position, float4 value) {
14+
inline void writeImageAsFloat2D(__write_only image2d_t image, int2 position, float4 value) {
1515
int dataType = get_image_channel_data_type(image);
1616
if(dataType == CLK_FLOAT || dataType == CLK_SNORM_INT16 || dataType == CLK_UNORM_INT16) {
1717
write_imagef(image, position, value);
@@ -31,11 +31,9 @@ __kernel void cast2D(
3131
__private float maximum
3232
) {
3333
const int2 pos = {get_global_id(0), get_global_id(1)};
34+
float4 value = readImageAsFloat2D(input, sampler, pos);
3435
if(normalize == 1) {
35-
float4 value = readImageAsFloat2D(input, sampler, pos);
3636
value = (value - minimum) / (maximum - minimum);
37-
writeImageAsFloat2D(output, pos, value*scaleFactor);
38-
} else {
39-
writeImageAsFloat2D(output, pos, readImageAsFloat2D(input, sampler, pos)*scaleFactor);
4037
}
38+
writeImageAsFloat2D(output, pos, value*scaleFactor);
4139
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
__constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
2+
3+
inline float4 readImageAsFloat3D(__read_only image3d_t image, sampler_t sampler, int4 position) {
4+
int dataType = get_image_channel_data_type(image);
5+
if(dataType == CLK_FLOAT || dataType == CLK_SNORM_INT16 || dataType == CLK_UNORM_INT16) {
6+
return read_imagef(image, sampler, position);
7+
} else if(dataType == CLK_SIGNED_INT8 || dataType == CLK_SIGNED_INT16 || dataType == CLK_SIGNED_INT32) {
8+
return convert_float4(read_imagei(image, sampler, position));
9+
} else {
10+
return convert_float4(read_imageui(image, sampler, position));
11+
}
12+
}
13+
14+
#ifdef fast_3d_image_writes
15+
inline void writeImageAsFloat3D(__write_only image3d_t image, int4 position, float4 value) {
16+
int dataType = get_image_channel_data_type(image);
17+
if(dataType == CLK_FLOAT || dataType == CLK_SNORM_INT16 || dataType == CLK_UNORM_INT16) {
18+
write_imagef(image, position, value);
19+
} else if(dataType == CLK_SIGNED_INT8 || dataType == CLK_SIGNED_INT16 || dataType == CLK_SIGNED_INT32) {
20+
write_imagei(image, position, convert_int4(round(value)));
21+
} else {
22+
write_imageui(image, position, convert_uint4(round(value)));
23+
}
24+
}
25+
26+
__kernel void cast3D(
27+
__read_only image3d_t input,
28+
__write_only image3d_t output,
29+
__private float scaleFactor,
30+
__private char normalize,
31+
__private float minimum,
32+
__private float maximum
33+
) {
34+
const int4 pos = {get_global_id(0), get_global_id(1), get_global_id(2), 0};
35+
if(normalize == 1) {
36+
float4 value = readImageAsFloat3D(input, sampler, pos);
37+
value = (value - minimum) / (maximum - minimum);
38+
writeImageAsFloat3D(output, pos, value*scaleFactor);
39+
} else {
40+
writeImageAsFloat3D(output, pos, readImageAsFloat3D(input, sampler, pos)*scaleFactor);
41+
}
42+
}
43+
#else
44+
45+
inline void writeImageAsFloat3D(
46+
__global TYPE* output,
47+
const int4 pos,
48+
const int2 size,
49+
const int channels,
50+
const float4 value
51+
) {
52+
float valuePtr[4] = {value.x, value.y, value.z, value.w};
53+
for(int i = 0; i < channels; ++i)
54+
output[(pos.x + pos.y*size.x + pos.z*size.x*size.y)*channels + i] = valuePtr[i];
55+
}
56+
57+
__kernel void cast3DBuffer(
58+
__read_only image3d_t input,
59+
__global TYPE* output,
60+
__private const float scaleFactor,
61+
__private const char normalize,
62+
__private const float minimum,
63+
__private const float maximum,
64+
__private const int channels
65+
) {
66+
const int4 pos = {get_global_id(0), get_global_id(1), get_global_id(2), 0};
67+
float4 value = readImageAsFloat3D(input, sampler, pos);
68+
if(normalize == 1) {
69+
value = (value - minimum) / (maximum - minimum);
70+
}
71+
writeImageAsFloat3D(output, pos, (int2)(get_global_size(0), get_global_size(1)), channels, value*scaleFactor);
72+
}
73+
#endif

source/FAST/Algorithms/Thresholding/OtsuThresholding.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
namespace fast {
66

7-
std::array<int, 256> calculateGlobalHistogram(ImageAccess::pointer& access, const int width, const int height) {
7+
std::array<int, 256> calculateGlobalHistogram(ImageAccess::pointer& access, const int size) {
88
std::array<int, 256> hist{};
9-
for(int i = 0; i < width*height; ++i) {
9+
for(int i = 0; i < size; ++i) {
1010
auto pixelValue = access->getScalarFast<uchar>(i);
1111
hist[pixelValue]++;
1212
}
@@ -47,9 +47,6 @@ void OtsuThresholding::execute() {
4747
reportWarning() << "Otsu thresholding implementation only supports 1 channel images. " <<
4848
"Since your input image has more than 1 channels, only the first channel (red) is used." << reportEnd();
4949
}
50-
if(input->getDimensions() == 3) {
51-
throw Exception("Otsu thresholding is currently only implemented for 2D images");
52-
}
5350

5451
if(input->getDataType() != TYPE_UINT8) {
5552
input = ImageCaster::create(TYPE_UINT8, 255, true)
@@ -60,7 +57,7 @@ void OtsuThresholding::execute() {
6057

6158
auto access = input->getImageAccess(ACCESS_READ);
6259
// TODO Move histogram operation somewhere else?
63-
auto histogram = calculateGlobalHistogram(access, input->getWidth(), input->getHeight());
60+
auto histogram = calculateGlobalHistogram(access, input->getNrOfVoxels());
6461
access->release();
6562
auto cumulativeProbs = calculateCumulativePixelProbabilities(histogram, input->getNrOfVoxels());
6663
auto probs = calculatePixelProbabilities(histogram, input->getNrOfVoxels());
@@ -115,8 +112,9 @@ void OtsuThresholding::execute() {
115112
}
116113
}
117114
// Segment using multiple threshold
118-
auto segmentation = Image::create(input->getWidth(), input->getHeight(), TYPE_UINT8, 1);
115+
auto segmentation = Image::create(input->getSize(), TYPE_UINT8, 1);
119116
segmentation->setSpacing(input->getSpacing());
117+
SceneGraph::setParentNode(segmentation, input);
120118
auto segmentationAccess = segmentation->getImageAccess(ACCESS_READ_WRITE);
121119
auto inputAccess = input->getImageAccess(ACCESS_READ);
122120
for(int i = 0; i < input->getNrOfVoxels(); ++i) {
@@ -161,11 +159,13 @@ void OtsuThresholding::execute() {
161159
}
162160
}
163161
// Segment using multiple threshold
164-
auto segmentation = Image::create(input->getWidth(), input->getHeight(), TYPE_UINT8, 1);
162+
auto segmentation = Image::create(input->getSize(), TYPE_UINT8, 1);
163+
segmentation->setSpacing(input->getSpacing());
164+
SceneGraph::setParentNode(segmentation, input);
165165
auto segmentationAccess = segmentation->getImageAccess(ACCESS_READ_WRITE);
166166
auto inputAccess = input->getImageAccess(ACCESS_READ);
167167
for(int i = 0; i < input->getNrOfVoxels(); ++i) {
168-
uchar value = inputAccess->getScalarFast<uchar>(i);
168+
auto value = inputAccess->getScalarFast<uchar>(i);
169169
uchar segmentationClass;
170170
if(value < bestThreshold1) {
171171
segmentationClass = 0;

source/FAST/Algorithms/Thresholding/OtsuThresholdingTests.cpp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,39 @@
77
using namespace fast;
88

99
TEST_CASE("Otsu thresholding 2D uint8 image", "[fast][OtsuThresholding][visual]") {
10-
auto importer = ImageFileImporter::create(Config::getTestDataPath() + "US/US-2D.jpg");
11-
auto segment = OtsuThresholding::create()->connect(importer);
12-
Display2DArgs args;
13-
args.image = importer;
14-
args.segmentation = segment;
15-
args.timeout = 1000;
16-
display2D(args);
10+
for(int c = 2; c < 5; ++c) {
11+
auto importer = ImageFileImporter::create(Config::getTestDataPath() + "US/US-2D.jpg");
12+
auto segment = OtsuThresholding::create(c)->connect(importer);
13+
Display2DArgs args;
14+
args.image = importer;
15+
args.segmentation = segment;
16+
args.timeout = 1000;
17+
display2D(args);
18+
}
1719
}
1820

19-
2021
TEST_CASE("Otsu thresholding 2D float image", "[fast][OtsuThresholding][visual]") {
21-
auto importer = ImageFileImporter::create(Config::getTestDataPath() + "US/US-2D.jpg");
22-
auto caster = ImageCaster::create(TYPE_FLOAT, 2.0f)->connect(importer);
23-
auto segment = OtsuThresholding::create()->connect(caster);
24-
Display2DArgs args;
25-
args.image = importer;
26-
args.segmentation = segment;
27-
args.timeout = 1000;
28-
display2D(args);
22+
for(int c = 2; c < 5; ++c) {
23+
auto importer = ImageFileImporter::create(Config::getTestDataPath() + "US/US-2D.jpg");
24+
auto caster = ImageCaster::create(TYPE_FLOAT, 2.0f)->connect(importer);
25+
auto segment = OtsuThresholding::create(c)->connect(caster);
26+
Display2DArgs args;
27+
args.image = importer;
28+
args.segmentation = segment;
29+
args.timeout = 1000;
30+
display2D(args);
31+
}
32+
}
33+
34+
TEST_CASE("Otsu thresholding 3D integer image", "[fast][OtsuThresholding][visual]") {
35+
for(int c = 2; c < 5; ++c) {
36+
auto importer = ImageFileImporter::create(Config::getTestDataPath() + "CT/CT-Thorax.mhd");
37+
auto segment = OtsuThresholding::create(c)->connect(importer);
38+
Display3DArgs args;
39+
args.image = importer;
40+
args.segmentation = segment;
41+
args.timeout = 1000;
42+
display3D(args);
43+
}
2944
}
45+

source/FAST/OpenCLProgram.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,12 @@ Kernel::Kernel(cl::Kernel clKernel, OpenCLDevice::pointer device) {
185185
// TODO Handle possible need for recompiling cached kernels
186186
throw Exception("OpenCL exception caught when trying to get kernel argument information: " + std::string(e.what()) + "(" + getCLErrorString(e.err()) + "). You might need to recompile your OpenCL code, delete the cache.");
187187
}
188+
m_initialized = true;
188189
}
189190

190191
cl::Kernel Kernel::getHandle() const {
192+
if(!m_initialized)
193+
throw Exception("Kernel object has not been initialized. Did you forgot = getKernel()?");
191194
return m_kernel;
192195
}
193196

@@ -240,10 +243,14 @@ KernelArgument Kernel::getArg(std::string name) const {
240243
}
241244

242245
bool Kernel::allArgumentsGotValue() const {
246+
if(!m_initialized)
247+
throw Exception("Kernel object has not been initialized. Did you forgot = getKernel()?");
243248
return m_argGotValue.size() == getNumberOfArgs();
244249
}
245250

246251
std::vector<std::string> Kernel::getArgumentsWithoutValue() const {
252+
if(!m_initialized)
253+
throw Exception("Kernel object has not been initialized. Did you forgot = getKernel()?");
247254
std::vector<std::string> list;
248255
for(int i = 0; i < getNumberOfArgs(); ++i) {
249256
if(m_argGotValue.count(i) == 0)
@@ -274,10 +281,16 @@ void Kernel::setTensorArg(std::string name, std::shared_ptr<fast::Tensor> tensor
274281
}
275282

276283
void Kernel::checkIndex(int index) const {
284+
if(!m_initialized)
285+
throw Exception("Kernel object has not been initialized. Did you forgot = getKernel()?");
277286
if(index >= getNumberOfArgs() || index < 0)
278287
throw Exception("Kernel does not have an argument with index " + std::to_string(index) + ", number of arguments is: " + std::to_string(getNumberOfArgs()));
279288
}
280289

290+
Kernel::Kernel() {
291+
m_initialized = false;
292+
}
293+
281294
OpenCLBuffer::OpenCLBuffer(std::size_t size, OpenCLDevice::pointer device, KernelMemoryAccess kernelAccess,
282295
HostMemoryAccess hostAccess, const void *data) {
283296
std::map<KernelMemoryAccess, cl_mem_flags> kernelMemoryAccessMap = {

source/FAST/OpenCLProgram.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class OpenCLBuffer;
8282
class FAST_EXPORT Kernel {
8383
public:
8484
explicit Kernel(cl::Kernel clKernel, OpenCLDevice::pointer device);
85+
Kernel();
8586
cl::Kernel getHandle() const;
8687

8788
#ifndef SWIG
@@ -165,6 +166,7 @@ class FAST_EXPORT Kernel {
165166
std::map<int, KernelArgument> m_argInfoByIndex;
166167
std::set<int> m_argGotValue;
167168
std::vector<OpenCLBuffer> m_buffers;
169+
bool m_initialized = false;
168170
};
169171

170172
template <class T>

0 commit comments

Comments
 (0)