Skip to content

Commit 05096fd

Browse files
committed
feat: allow returning images in json in base64 format
1 parent 28060f9 commit 05096fd

File tree

7 files changed

+165
-9
lines changed

7 files changed

+165
-9
lines changed

src/dto/ddtypes.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ namespace dd
3131

3232
const oatpp::ClassId GpuIdsClass::CLASS_ID("GpuIds");
3333

34+
const oatpp::ClassId ImageClass::CLASS_ID("Image");
35+
3436
template <>
3537
const oatpp::ClassId DTOVectorClass<double>::CLASS_ID("vector<double>");
3638

src/dto/ddtypes.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@
2222
#ifndef DD_DTO_TYPES_HPP
2323
#define DD_DTO_TYPES_HPP
2424

25+
#include <opencv2/opencv.hpp>
26+
2527
#include "oatpp/core/Types.hpp"
2628
#include "oatpp/parser/json/mapping/ObjectMapper.hpp"
2729
#include "apidata.h"
30+
#include "utils/cv_utils.hpp"
2831

2932
namespace dd
3033
{
@@ -46,10 +49,51 @@ namespace dd
4649
}
4750
};
4851

52+
struct VImage
53+
{
54+
cv::Mat _img;
55+
#ifdef USE_CUDA_CV
56+
cv::cuda::GpuMat _cuda_img;
57+
#endif
58+
std::string _ext = ".png";
59+
60+
VImage(const cv::Mat &img, const std::string &ext = ".png")
61+
: _img(img), _ext(ext)
62+
{
63+
}
64+
#ifdef USE_CUDA_CV
65+
VImage(const cv::cuda::GpuMat &cuda_img, const std::string &ext = ".png")
66+
: _cuda_img(cuda_img), _ext(ext)
67+
{
68+
}
69+
#endif
70+
bool is_cuda() const
71+
{
72+
#ifdef USE_CUDA_CV
73+
return !_cuda_img.empty();
74+
#else
75+
return false;
76+
#endif
77+
}
78+
79+
/** get image on CPU whether it's on GPU or not */
80+
const cv::Mat &get_img()
81+
{
82+
#ifdef USE_CUDA_CV
83+
if (is_cuda())
84+
{
85+
_cuda_img.download(_img);
86+
}
87+
#endif
88+
return _img;
89+
}
90+
};
91+
4992
namespace __class
5093
{
5194
class APIDataClass;
5295
class GpuIdsClass;
96+
class ImageClass;
5397
template <typename T> class DTOVectorClass;
5498
}
5599

@@ -59,6 +103,8 @@ namespace dd
59103
typedef oatpp::data::mapping::type::Primitive<VGpuIds,
60104
__class::GpuIdsClass>
61105
GpuIds;
106+
typedef oatpp::data::mapping::type::Primitive<VImage, __class::ImageClass>
107+
DTOImage;
62108
template <typename T>
63109
using DTOVector
64110
= oatpp::data::mapping::type::Primitive<std::vector<T>,
@@ -89,6 +135,18 @@ namespace dd
89135
}
90136
};
91137

138+
class ImageClass
139+
{
140+
public:
141+
static const oatpp::ClassId CLASS_ID;
142+
143+
static oatpp::Type *getType()
144+
{
145+
static oatpp::Type type(CLASS_ID);
146+
return &type;
147+
}
148+
};
149+
92150
template <typename T> class DTOVectorClass
93151
{
94152
public:
@@ -113,6 +171,9 @@ namespace dd
113171
{
114172
(void)type;
115173
(void)deserializer;
174+
// XXX: this has a failure case if the stream contains excaped "{" or "}"
175+
// Since this is a temporary workaround until we use DTO everywhere, it
176+
// might not be required to be fixed
116177
if (caret.isAtChar('{'))
117178
{
118179
auto start = caret.getCurrData();
@@ -221,6 +282,30 @@ namespace dd
221282
}
222283
}
223284

285+
static inline oatpp::Void
286+
imageDeserialize(oatpp::parser::json::mapping::Deserializer *deserializer,
287+
oatpp::parser::Caret &caret,
288+
const oatpp::Type *const type)
289+
{
290+
(void)type;
291+
auto str_base64
292+
= deserializer->deserialize(caret, oatpp::String::Class::getType())
293+
.cast<oatpp::String>();
294+
return DTOImage(VImage{ cv_utils::base64_to_image(*str_base64) });
295+
}
296+
297+
static inline void
298+
imageSerialize(oatpp::parser::json::mapping::Serializer *serializer,
299+
oatpp::data::stream::ConsistentOutputStream *stream,
300+
const oatpp::Void &obj)
301+
{
302+
(void)serializer;
303+
auto img_dto = obj.cast<DTOImage>();
304+
std::string encoded
305+
= cv_utils::image_to_base64(img_dto->get_img(), img_dto->_ext);
306+
stream->writeSimple(encoded);
307+
}
308+
224309
// Inspired by oatpp json deserializer
225310
template <typename T> inline T readVecElement(oatpp::parser::Caret &caret);
226311

src/dto/predict_out.hpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,24 @@ namespace dd
116116

117117
DTO_FIELD_INFO(vals)
118118
{
119-
info->description = "[Unsupervised] Array containing model output "
120-
"values. Can be in different formats: double, "
121-
"binarized double, booleans, binarized string";
119+
info->description
120+
= "[Unsupervised] Array containing model output "
121+
"values. Can be in different formats: double, "
122+
"binarized double, booleans, binarized string, base64 image";
122123
}
123124
DTO_FIELD(Any, vals);
125+
126+
DTO_FIELD_INFO(images)
127+
{
128+
info->description
129+
= "[Unsupervised] Array of images returned by the model";
130+
}
131+
DTO_FIELD(Vector<DTOImage>, images);
132+
133+
DTO_FIELD_INFO(imgsize)
134+
{
135+
info->description = "[Unsupervised] Image size";
136+
}
124137
DTO_FIELD(Object<Dimensions>, imgsize);
125138

126139
DTO_FIELD_INFO(confidences)
@@ -140,6 +153,7 @@ namespace dd
140153
DTO_FIELD(String, index_uri);
141154

142155
public:
156+
// XXX: Legacy & deprecated
143157
std::vector<cv::Mat> _images; /**<allow to pass images in the DTO */
144158
};
145159

src/unsupervisedoutputconnector.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
#ifndef UNSUPERVISEDOUTPUTCONNECTOR_H
2323
#define UNSUPERVISEDOUTPUTCONNECTOR_H
2424

25+
#include <vector>
26+
#include <map>
27+
2528
#include "dto/predict_out.hpp"
2629

2730
namespace dd
@@ -119,6 +122,9 @@ namespace dd
119122
_bool_binarized = ad_out.get("bool_binarized").get<bool>();
120123
else if (ad_out.has("string_binarized"))
121124
_string_binarized = ad_out.get("string_binarized").get<bool>();
125+
126+
if (ad_out.has("encoding"))
127+
_image_encoding = ad_out.get("encoding").get<std::string>();
122128
}
123129

124130
void set_results(std::vector<UnsupervisedResult> &&results)
@@ -321,8 +327,16 @@ namespace dd
321327
auto pred_dto = DTO::Prediction::createShared();
322328
pred_dto->uri = _vvres.at(i)._uri.c_str();
323329
if (_vvres.at(i)._images.size() != 0)
324-
pred_dto->_images = _vvres.at(i)._images;
325-
if (_bool_binarized)
330+
{
331+
// XXX: legacy
332+
pred_dto->_images = _vvres.at(i)._images;
333+
334+
pred_dto->images = oatpp::Vector<DTO::DTOImage>::createShared();
335+
for (auto &image : _vvres.at(i)._images)
336+
pred_dto->images->push_back(
337+
DTO::VImage{ image, _image_encoding });
338+
}
339+
else if (_bool_binarized)
326340
pred_dto->vals
327341
= DTO::DTOVector<bool>(std::move(_vvres.at(i)._bvals));
328342
else if (_string_binarized)
@@ -368,6 +382,8 @@ namespace dd
368382
= false; /**< boolean binary representation of output values. */
369383
bool _string_binarized = false; /**< boolean string as binary
370384
representation of output values. */
385+
std::string _image_encoding
386+
= ".png"; /**< encoding used for output images */
371387
#ifdef USE_SIMSEARCH
372388
int _search_nn = 10; /**< default nearest neighbors per search. */
373389
#endif

src/utils/oatpp.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ namespace dd
3838
DTO::apiDataDeserialize);
3939
deser->setDeserializerMethod(DTO::GpuIds::Class::CLASS_ID,
4040
DTO::gpuIdsDeserialize);
41+
deser->setDeserializerMethod(DTO::DTOImage::Class::CLASS_ID,
42+
DTO::imageDeserialize);
4143
deser->setDeserializerMethod(DTO::DTOVector<double>::Class::CLASS_ID,
4244
DTO::vectorDeserialize<double>);
4345
deser->setDeserializerMethod(DTO::DTOVector<uint8_t>::Class::CLASS_ID,
@@ -49,6 +51,8 @@ namespace dd
4951
DTO::apiDataSerialize);
5052
ser->setSerializerMethod(DTO::GpuIds::Class::CLASS_ID,
5153
DTO::gpuIdsSerialize);
54+
ser->setSerializerMethod(DTO::DTOImage::Class::CLASS_ID,
55+
DTO::imageSerialize);
5256
ser->setSerializerMethod(DTO::DTOVector<double>::Class::CLASS_ID,
5357
DTO::vectorSerialize<double>);
5458
ser->setSerializerMethod(DTO::DTOVector<uint8_t>::Class::CLASS_ID,
@@ -191,6 +195,13 @@ namespace dd
191195
jval.PushBack(dto_gpuid->_ids[i], jdoc.GetAllocator());
192196
}
193197
}
198+
else if (polymorph.getValueType() == DTO::DTOImage::Class::getType())
199+
{
200+
auto dto_img = polymorph.cast<DTO::DTOImage>();
201+
std::string img_str
202+
= cv_utils::image_to_base64(dto_img->get_img(), dto_img->_ext);
203+
jval.SetString(img_str.c_str(), jdoc.GetAllocator());
204+
}
194205
else if (polymorph.getValueType()->classId.id
195206
== oatpp::data::mapping::type::__class::AbstractVector::
196207
CLASS_ID.id

tests/ut-tensorrtapi.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,31 @@ TEST(tensorrtapi, service_predict_gan_onnx)
271271
ASSERT_TRUE(jd["body"]["predictions"][0]["vals"].IsArray());
272272
ASSERT_EQ(jd["body"]["predictions"][0]["vals"].Size(), 360 * 360 * 3);
273273

274+
// predict to image
275+
jpredictstr
276+
= "{\"service\":\"" + sname
277+
+ "\",\"parameters\":{\"input\":{\"height\":360,"
278+
"\"width\":360,\"rgb\":true,\"scale\":0.00392,\"mean\":[0.5,0.5,0.5]"
279+
",\"std\":[0.5,0.5,0.5]},\"output\":{\"image\":true},\"mllib\":{"
280+
"\"extract_layer\":\"last\"}},\"data\":[\""
281+
+ cyclegan_onnx_repo + "horse.jpg\"]}";
282+
joutstr = japi.jrender(japi.service_predict(jpredictstr));
283+
jd = JDoc();
284+
// std::cout << "joutstr=" << joutstr << std::endl;
285+
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
286+
ASSERT_TRUE(!jd.HasParseError());
287+
ASSERT_EQ(200, jd["status"]["code"]);
288+
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
289+
ASSERT_TRUE(jd["body"]["predictions"][0]["images"].IsArray());
290+
ASSERT_EQ(jd["body"]["predictions"][0]["images"].Size(), 1);
291+
// png image
292+
std::string base64_img
293+
= jd["body"]["predictions"][0]["images"][0].GetString();
294+
// may be small differences between machines, versions of libpng/jpeg?
295+
ASSERT_NEAR(base64_img.size(), 388292, 100);
296+
// cv::imwrite("onnx_gan_base64.jpg", cv_utils::base64_to_image(base64_img));
297+
298+
// delete
274299
ASSERT_TRUE(fileops::file_exists(cyclegan_onnx_repo + "TRTengine_arch"
275300
+ get_trt_archi() + "_fp16_bs1"));
276301

tests/ut-torchapi.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ static std::string iterations_ttransformer_cpu = "100";
124124
static std::string iterations_ttransformer_gpu = "1000";
125125

126126
static std::string iterations_resnet50 = "200";
127+
/// different values to mitigate failure due to randomness
128+
static std::string iterations_resnet50_split = "300";
127129
static std::string iterations_vit = "200";
128130
static std::string iterations_detection = "200";
129131
static std::string iterations_deeplabv3 = "200";
@@ -831,7 +833,7 @@ TEST(torchapi, service_train_images_split)
831833
std::string jtrainstr
832834
= "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
833835
"\"mllib\":{\"solver\":{\"iterations\":"
834-
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
836+
+ iterations_resnet50_split + ",\"base_lr\":" + torch_lr
835837
+ ",\"iter_size\":4,\"solver_type\":\"ADAM\",\"test_"
836838
"interval\":200},\"net\":{\"batch_size\":4},\"nclasses\":2,"
837839
"\"resume\":false},"
@@ -846,7 +848,8 @@ TEST(torchapi, service_train_images_split)
846848
ASSERT_TRUE(!jd.HasParseError());
847849
ASSERT_EQ(201, jd["status"]["code"]);
848850

849-
ASSERT_TRUE(jd["body"]["measure"]["iteration"] == 200) << "iterations";
851+
int it_count = std::stoi(iterations_resnet50_split);
852+
ASSERT_TRUE(jd["body"]["measure"]["iteration"] == it_count) << "iterations";
850853
ASSERT_TRUE(jd["body"]["measure"]["train_loss"].GetDouble() <= 3.0)
851854
<< "loss";
852855

@@ -862,9 +865,9 @@ TEST(torchapi, service_train_images_split)
862865
remove(ff.c_str());
863866
}
864867
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
865-
+ iterations_resnet50 + ".ptw"));
868+
+ iterations_resnet50_split + ".ptw"));
866869
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
867-
+ iterations_resnet50 + ".pt"));
870+
+ iterations_resnet50_split + ".pt"));
868871
fileops::clear_directory(resnet50_train_repo + "train.lmdb");
869872
fileops::clear_directory(resnet50_train_repo + "test_0.lmdb");
870873
fileops::remove_dir(resnet50_train_repo + "train.lmdb");

0 commit comments

Comments
 (0)