Skip to content

Commit 28060f9

Browse files
committed
test: add test for torch inference in base64
1 parent edb28c1 commit 28060f9

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

src/backends/caffe/caffelib.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3551,6 +3551,9 @@ namespace dd
35513551
= _net->blob_by_name("cont_seq");
35523552
// cont_seq is TxN
35533553

3554+
#pragma GCC diagnostic push
3555+
#pragma GCC diagnostic ignored "-Warray-bounds"
3556+
35543557
CSVTSCaffeInputFileConn *ic
35553558
= reinterpret_cast<CSVTSCaffeInputFileConn *>(&inputc);
35563559

@@ -3638,6 +3641,7 @@ namespace dd
36383641
series.push_back(ts);
36393642
}
36403643
}
3644+
#pragma GCC diagnostic pop
36413645
}
36423646
else // classification
36433647
{

src/utils/cv_utils.hpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
#define DD_UTILS_CVUTILS_HPP
2424

2525
#include <vector>
26+
#include <opencv2/opencv.hpp>
27+
#include "ext/base64/base64.h"
2628

2729
namespace dd
2830
{
2931
namespace cv_utils
3032
{
3133
/** Convert an int fourcc (from a video) to string format */
32-
std::string fourcc_to_string(int fourcc)
34+
inline std::string fourcc_to_string(int fourcc)
3335
{
3436
union
3537
{
@@ -44,6 +46,33 @@ namespace dd
4446
(i32_c.c[2] >= ' ' && i32_c.c[2] < 128) ? i32_c.c[2] : '?',
4547
(i32_c.c[3] >= ' ' && i32_c.c[3] < 128) ? i32_c.c[3] : '?');
4648
}
49+
50+
inline cv::Mat base64_to_image(const std::string &str_base64)
51+
{
52+
std::string img_str;
53+
if (!Base64::Decode(str_base64, &img_str))
54+
throw std::runtime_error("Image could not be decoded");
55+
56+
std::vector<unsigned char> vdat(img_str.begin(), img_str.end());
57+
cv::Mat img
58+
= cv::Mat(cv::imdecode(cv::Mat(vdat, false), cv::IMREAD_UNCHANGED));
59+
return img;
60+
}
61+
62+
/** Convert image to base64 string */
63+
inline std::string image_to_base64(const cv::Mat &mat,
64+
const std::string &ext)
65+
{
66+
std::vector<uint8_t> buffer;
67+
cv::imencode(ext, mat, buffer);
68+
69+
// encode to base64
70+
std::string byte_str(buffer.begin(), buffer.end());
71+
std::string encoded;
72+
if (!Base64::Encode(byte_str, &encoded))
73+
throw std::runtime_error("Image could not be encoded");
74+
return encoded;
75+
}
4776
}
4877
}
4978

tests/ut-torchapi.cc

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,23 @@
2020
* You should have received a copy of the GNU Lesser General Public License
2121
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
2222
*/
23-
24-
#include "deepdetect.h"
25-
#include "jsonapi.h"
26-
#include "txtinputfileconn.h"
2723
#include <gtest/gtest.h>
2824
#include <stdio.h>
2925
#include <iostream>
3026
#include <numeric>
31-
#include "backends/torch/native/templates/nbeats.h"
27+
#pragma GCC diagnostic push
28+
#pragma GCC diagnostic ignored "-Wunused-parameter"
29+
#pragma GCC diagnostic ignored "-Wunused-variable"
3230
#include <torch/torch.h>
31+
#pragma GCC diagnostic pop
3332
#include <rapidjson/istreamwrapper.h>
3433

34+
#include "deepdetect.h"
35+
#include "jsonapi.h"
36+
#include "txtinputfileconn.h"
37+
#include "utils/cv_utils.hpp"
38+
#include "backends/torch/native/templates/nbeats.h"
39+
3540
using namespace dd;
3641

3742
static std::string ok_str = "{\"status\":{\"code\":200,\"msg\":\"OK\"}}";
@@ -383,6 +388,7 @@ TEST(torchapi, service_predict_object_detection)
383388
"\"best_bbox\":3}},\"data\":[\""
384389
+ detect_repo + "cat.jpg\"]}";
385390
joutstr = japi.jrender(japi.service_predict(jpredictstr));
391+
jd = JDoc();
386392
std::cout << "joutstr=" << joutstr << std::endl;
387393
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
388394

@@ -392,6 +398,26 @@ TEST(torchapi, service_predict_object_detection)
392398

393399
auto &preds_best = jd["body"]["predictions"][0]["classes"];
394400
ASSERT_EQ(preds_best.Size(), 3);
401+
402+
// base64
403+
cv::Mat img = cv::imread(detect_repo + "cat.jpg");
404+
std::string b64_str = cv_utils::image_to_base64(img, ".png");
405+
jpredictstr = "{\"service\":\"detectserv\",\"parameters\":{"
406+
"\"input\":{\"height\":224,"
407+
"\"width\":224},\"output\":{\"bbox\":true, "
408+
"\"best_bbox\":3}},\"data\":[\""
409+
+ b64_str + "\"]}";
410+
joutstr = japi.jrender(japi.service_predict(jpredictstr));
411+
jd = JDoc();
412+
std::cout << "joutstr=" << joutstr << std::endl;
413+
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
414+
415+
ASSERT_TRUE(!jd.HasParseError());
416+
ASSERT_EQ(200, jd["status"]["code"]);
417+
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
418+
419+
auto &preds_best_b64 = jd["body"]["predictions"][0]["classes"];
420+
ASSERT_EQ(preds_best_b64.Size(), 3);
395421
}
396422

397423
TEST(torchapi, service_predict_object_detection_any_size)

0 commit comments

Comments
 (0)