Skip to content

Commit 8b309c7

Browse files
lzhangzzirexyc
authored andcommitted
wip
1 parent 8e658cd commit 8b309c7

File tree

18 files changed

+1415
-57
lines changed

18 files changed

+1415
-57
lines changed

csrc/mmdeploy/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,8 @@ if (MMDEPLOY_BUILD_SDK)
1818
add_subdirectory(net)
1919
add_subdirectory(codebase)
2020
add_subdirectory(apis)
21+
22+
if (TRITON_MMDEPLOY_BACKEND)
23+
add_subdirectory(triton)
24+
endif ()
2125
endif ()

csrc/mmdeploy/apis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ add_subdirectory(java)
99
if (MMDEPLOY_BUILD_SDK_PYTHON_API)
1010
add_subdirectory(python)
1111
endif ()
12+

csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,28 @@ namespace mmdeploy {
1111

1212
namespace cxx {
1313

14-
class Pipeline : public NonMovable {
14+
class Pipeline : public UniqueHandle<mmdeploy_pipeline_t> {
1515
public:
1616
Pipeline(const Value& config, const Context& context) {
17-
mmdeploy_pipeline_t pipeline{};
18-
auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &pipeline);
17+
auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &handle_);
1918
if (ec != MMDEPLOY_SUCCESS) {
2019
throw_exception(static_cast<ErrorCode>(ec));
2120
}
22-
pipeline_ = pipeline;
2321
}
2422

2523
~Pipeline() {
26-
if (pipeline_) {
27-
mmdeploy_pipeline_destroy(pipeline_);
28-
pipeline_ = nullptr;
24+
if (handle_) {
25+
mmdeploy_pipeline_destroy(handle_);
26+
handle_ = nullptr;
2927
}
3028
}
3129

30+
Pipeline(Pipeline&&) noexcept = default;
31+
Pipeline& operator=(Pipeline&&) noexcept = default;
32+
3233
Value Apply(const Value& inputs) {
3334
mmdeploy_value_t tmp{};
34-
auto ec = mmdeploy_pipeline_apply(pipeline_, (mmdeploy_value_t)&inputs, &tmp);
35+
auto ec = mmdeploy_pipeline_apply(handle_, (mmdeploy_value_t)&inputs, &tmp);
3536
if (ec != MMDEPLOY_SUCCESS) {
3637
throw_exception(static_cast<ErrorCode>(ec));
3738
}
@@ -50,7 +51,7 @@ class Pipeline : public NonMovable {
5051
if (ec != MMDEPLOY_SUCCESS) {
5152
throw_exception(static_cast<ErrorCode>(ec));
5253
}
53-
auto outputs = Apply(*reinterpret_cast<Value*>(inputs));
54+
auto outputs = this->Apply(*reinterpret_cast<Value*>(inputs));
5455
mmdeploy_value_destroy(inputs);
5556

5657
return outputs;
@@ -65,9 +66,6 @@ class Pipeline : public NonMovable {
6566
}
6667
return rets;
6768
}
68-
69-
private:
70-
mmdeploy_pipeline_t pipeline_{};
7169
};
7270

7371
} // namespace cxx

csrc/mmdeploy/core/logger.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ MMDEPLOY_API void SetLogger(spdlog::logger *logger);
4747
#endif
4848

4949
#ifdef SPDLOG_LOGGER_CALL
50-
#define MMDEPLOY_LOG(level, ...) SPDLOG_LOGGER_CALL(mmdeploy::GetLogger(), level, __VA_ARGS__)
50+
#define MMDEPLOY_LOG(level, ...) SPDLOG_LOGGER_CALL(::mmdeploy::GetLogger(), level, __VA_ARGS__)
5151
#else
5252
#define MMDEPLOY_LOG(level, ...) mmdeploy::GetLogger()->log(level, __VA_ARGS__)
5353
#endif

csrc/mmdeploy/core/model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ struct model_meta_info_t {
3030

3131
struct deploy_meta_info_t {
3232
std::string version;
33+
std::string task;
3334
std::vector<model_meta_info_t> models;
34-
MMDEPLOY_ARCHIVE_MEMBERS(version, models);
35+
MMDEPLOY_ARCHIVE_MEMBERS(version, task, models);
3536
};
3637

3738
class ModelImpl;

csrc/mmdeploy/graph/inference.cpp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,30 @@ using namespace framework;
1414
InferenceBuilder::InferenceBuilder(Value config) : Builder(std::move(config)) {}
1515

1616
Result<unique_ptr<Node>> InferenceBuilder::BuildImpl() {
17-
auto& model_config = config_["params"]["model"];
18-
Model model;
19-
if (model_config.is_any<Model>()) {
20-
model = model_config.get<Model>();
21-
} else {
22-
auto model_name = model_config.get<string>();
23-
if (auto m = Maybe{config_} / "context" / "model" / model_name / identity<Model>{}) {
24-
model = *m;
17+
Value pipeline_config;
18+
auto context = config_.value("context", Value(ValueType::kObject));
19+
const auto& params = config_["params"];
20+
if (params.contains("model")) {
21+
auto& model_config = params["model"];
22+
Model model;
23+
if (model_config.is_any<Model>()) {
24+
model = model_config.get<Model>();
2525
} else {
26-
model = Model(model_name);
26+
auto model_name = model_config.get<string>();
27+
if (auto m = Maybe{config_} / "context" / "model" / model_name / identity<Model>{}) {
28+
model = *m;
29+
} else {
30+
model = Model(model_name);
31+
}
2732
}
33+
OUTCOME_TRY(pipeline_config, model.ReadConfig("pipeline.json"));
34+
context["model"] = std::move(model);
35+
} else if (params.contains("pipeline")) {
36+
assert(context.contains("model"));
37+
auto model = context["model"].get<Model>();
38+
OUTCOME_TRY(pipeline_config, model.ReadConfig(params["pipeline"].get<std::string>()));
2839
}
2940

30-
OUTCOME_TRY(auto pipeline_config, model.ReadConfig("pipeline.json"));
31-
32-
auto context = config_.value("context", Value(ValueType::kObject));
33-
context["model"] = std::move(model);
34-
3541
if (context.contains("scope")) {
3642
auto name = config_.value("name", config_["type"].get<std::string>());
3743
auto scope = context["scope"].get_ref<profiler::Scope*&>()->CreateScope(name);

csrc/mmdeploy/graph/task.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ Result<unique_ptr<Node>> TaskBuilder::BuildImpl() {
9696
task->is_thread_safe_ = config_.value("is_thread_safe", false);
9797
return std::move(task);
9898
} catch (const std::exception& e) {
99+
MMDEPLOY_ERROR("unhandled exception: {}", e.what());
99100
MMDEPLOY_ERROR("error parsing config: {}", config_);
100101
return nullptr;
101102
}

csrc/mmdeploy/preprocess/transform_module.cpp

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,46 +10,112 @@ namespace mmdeploy {
1010

1111
class TransformModule {
1212
public:
13-
~TransformModule();
14-
TransformModule(TransformModule&&) noexcept;
13+
~TransformModule() = default;
14+
TransformModule(TransformModule&&) noexcept = default;
1515

16-
explicit TransformModule(const Value& args);
17-
Result<Value> operator()(const Value& input);
16+
explicit TransformModule(const Value& args) {
17+
const auto type = "Compose";
18+
auto creator = gRegistry<transform::Transform>().Get(type);
19+
if (!creator) {
20+
MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type,
21+
gRegistry<transform::Transform>().List());
22+
throw_exception(eEntryNotFound);
23+
}
24+
auto cfg = args;
25+
if (cfg.contains("device")) {
26+
MMDEPLOY_WARN("force using device: {}", cfg["device"].get<const char*>());
27+
auto device = Device(cfg["device"].get<const char*>());
28+
cfg["context"]["device"] = device;
29+
cfg["context"]["stream"] = Stream::GetDefault(device);
30+
}
31+
transform_ = creator->Create(cfg);
32+
}
33+
34+
Result<Value> operator()(const Value& input) {
35+
auto data = input;
36+
OUTCOME_TRY(transform_->Apply(data));
37+
return data;
38+
}
1839

1940
private:
2041
std::unique_ptr<transform::Transform> transform_;
2142
};
2243

23-
TransformModule::~TransformModule() = default;
44+
MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Transform, 0), [](const Value& config) {
45+
return CreateTask(TransformModule{config});
46+
});
2447

25-
TransformModule::TransformModule(TransformModule&&) noexcept = default;
48+
#if 0
49+
class Preload {
50+
public:
51+
explicit Preload(const Value& args) {
52+
const auto type = "Compose";
53+
auto creator = gRegistry<transform::Transform>().Get(type);
54+
if (!creator) {
55+
MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type,
56+
gRegistry<transform::Transform>().List());
57+
throw_exception(eEntryNotFound);
58+
}
59+
auto cfg = args;
60+
if (cfg.contains("device")) {
61+
MMDEPLOY_WARN("force using device: {}", cfg["device"].get<const char*>());
62+
auto device = Device(cfg["device"].get<const char*>());
63+
cfg["context"]["device"] = device;
64+
cfg["context"]["stream"] = Stream::GetDefault(device);
65+
}
66+
const auto& ctx = cfg["context"];
67+
ctx["device"].get_to(device_);
68+
ctx["stream"].get_to(stream_);
69+
}
2670

27-
TransformModule::TransformModule(const Value& args) {
28-
const auto type = "Compose";
29-
auto creator = gRegistry<transform::Transform>().Get(type);
30-
if (!creator) {
31-
MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type,
32-
gRegistry<transform::Transform>().List());
33-
throw_exception(eEntryNotFound);
71+
Result<Value> operator()(const Value& input) {
72+
auto data = input;
73+
if (device_.is_device()) {
74+
bool need_sync = false;
75+
OUTCOME_TRY(Process(data, need_sync));
76+
MMDEPLOY_ERROR("need_sync = {}", need_sync);
77+
MMDEPLOY_ERROR("{}", data);
78+
if (need_sync) {
79+
OUTCOME_TRY(stream_.Wait());
80+
}
81+
}
82+
return data;
3483
}
35-
auto cfg = args;
36-
if (cfg.contains("device")) {
37-
MMDEPLOY_WARN("force using device: {}", cfg["device"].get<const char*>());
38-
auto device = Device(cfg["device"].get<const char*>());
39-
cfg["context"]["device"] = device;
40-
cfg["context"]["stream"] = Stream::GetDefault(device);
84+
85+
Result<void> Process(Value& item, bool& need_sync) {
86+
if (item.is_any<Mat>()) {
87+
auto& mat = item.get_ref<Mat&>();
88+
if (mat.device().is_host()) {
89+
Mat tmp(mat.height(), mat.width(), mat.pixel_format(), mat.type(), device_);
90+
OUTCOME_TRY(stream_.Copy(mat.buffer(), tmp.buffer(), mat.byte_size()));
91+
mat = tmp;
92+
need_sync |= true;
93+
}
94+
} else if (item.is_any<Tensor>()) {
95+
auto& ten = item.get_ref<Tensor&>();
96+
if (ten.device().is_host()) {
97+
TensorDesc desc = ten.desc();
98+
desc.device = device_;
99+
Tensor tmp(desc);
100+
OUTCOME_TRY(stream_.Copy(ten.buffer(), tmp.buffer(), ten.byte_size()));
101+
ten = tmp;
102+
need_sync |= true;
103+
}
104+
} else if (item.is_array() || item.is_object()) {
105+
for (auto& child : item) {
106+
OUTCOME_TRY(Process(child, need_sync));
107+
}
108+
}
109+
return success();
41110
}
42-
transform_ = creator->Create(cfg);
43-
}
44111

45-
Result<Value> TransformModule::operator()(const Value& input) {
46-
auto data = input;
47-
OUTCOME_TRY(transform_->Apply(data));
48-
return data;
49-
}
112+
private:
113+
Device device_;
114+
Stream stream_;
115+
};
50116

51-
MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Transform, 0), [](const Value& config) {
52-
return CreateTask(TransformModule{config});
53-
});
117+
MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Preload, 0),
118+
[](const Value& config) { return CreateTask(Preload{config}); });
119+
#endif
54120

55121
} // namespace mmdeploy

0 commit comments

Comments
 (0)