@@ -10,46 +10,112 @@ namespace mmdeploy {
1010
1111class TransformModule {
1212 public:
13- ~TransformModule ();
14- TransformModule (TransformModule&&) noexcept ;
13+ ~TransformModule () = default ;
14+ TransformModule (TransformModule&&) noexcept = default ;
1515
16- explicit TransformModule (const Value& args);
17- Result<Value> operator ()(const Value& input);
16+ explicit TransformModule (const Value& args) {
17+ const auto type = " Compose" ;
18+ auto creator = gRegistry <transform::Transform>().Get (type);
19+ if (!creator) {
20+ MMDEPLOY_ERROR (" Unable to find Transform creator: {}. Available transforms: {}" , type,
21+ gRegistry <transform::Transform>().List ());
22+ throw_exception (eEntryNotFound);
23+ }
24+ auto cfg = args;
25+ if (cfg.contains (" device" )) {
26+ MMDEPLOY_WARN (" force using device: {}" , cfg[" device" ].get <const char *>());
27+ auto device = Device (cfg[" device" ].get <const char *>());
28+ cfg[" context" ][" device" ] = device;
29+ cfg[" context" ][" stream" ] = Stream::GetDefault (device);
30+ }
31+ transform_ = creator->Create (cfg);
32+ }
33+
34+ Result<Value> operator ()(const Value& input) {
35+ auto data = input;
36+ OUTCOME_TRY (transform_->Apply (data));
37+ return data;
38+ }
1839
1940 private:
2041 std::unique_ptr<transform::Transform> transform_;
2142};
2243
23- TransformModule::~TransformModule () = default ;
44+ MMDEPLOY_REGISTER_FACTORY_FUNC (Module, (Transform, 0 ), [](const Value& config) {
45+ return CreateTask (TransformModule{config});
46+ });
2447
25- TransformModule::TransformModule (TransformModule&&) noexcept = default ;
48+ #if 0
49+ class Preload {
50+ public:
51+ explicit Preload(const Value& args) {
52+ const auto type = "Compose";
53+ auto creator = gRegistry<transform::Transform>().Get(type);
54+ if (!creator) {
55+ MMDEPLOY_ERROR("Unable to find Transform creator: {}. Available transforms: {}", type,
56+ gRegistry<transform::Transform>().List());
57+ throw_exception(eEntryNotFound);
58+ }
59+ auto cfg = args;
60+ if (cfg.contains("device")) {
61+ MMDEPLOY_WARN("force using device: {}", cfg["device"].get<const char*>());
62+ auto device = Device(cfg["device"].get<const char*>());
63+ cfg["context"]["device"] = device;
64+ cfg["context"]["stream"] = Stream::GetDefault(device);
65+ }
66+ const auto& ctx = cfg["context"];
67+ ctx["device"].get_to(device_);
68+ ctx["stream"].get_to(stream_);
69+ }
2670
27- TransformModule::TransformModule (const Value& args) {
28- const auto type = " Compose" ;
29- auto creator = gRegistry <transform::Transform>().Get (type);
30- if (!creator) {
31- MMDEPLOY_ERROR (" Unable to find Transform creator: {}. Available transforms: {}" , type,
32- gRegistry <transform::Transform>().List ());
33- throw_exception (eEntryNotFound);
71+ Result<Value> operator()(const Value& input) {
72+ auto data = input;
73+ if (device_.is_device()) {
74+ bool need_sync = false;
75+ OUTCOME_TRY(Process(data, need_sync));
76+ MMDEPLOY_ERROR("need_sync = {}", need_sync);
77+ MMDEPLOY_ERROR("{}", data);
78+ if (need_sync) {
79+ OUTCOME_TRY(stream_.Wait());
80+ }
81+ }
82+ return data;
3483 }
35- auto cfg = args;
36- if (cfg.contains (" device" )) {
37- MMDEPLOY_WARN (" force using device: {}" , cfg[" device" ].get <const char *>());
38- auto device = Device (cfg[" device" ].get <const char *>());
39- cfg[" context" ][" device" ] = device;
40- cfg[" context" ][" stream" ] = Stream::GetDefault (device);
84+
85+ Result<void> Process(Value& item, bool& need_sync) {
86+ if (item.is_any<Mat>()) {
87+ auto& mat = item.get_ref<Mat&>();
88+ if (mat.device().is_host()) {
89+ Mat tmp(mat.height(), mat.width(), mat.pixel_format(), mat.type(), device_);
90+ OUTCOME_TRY(stream_.Copy(mat.buffer(), tmp.buffer(), mat.byte_size()));
91+ mat = tmp;
92+ need_sync |= true;
93+ }
94+ } else if (item.is_any<Tensor>()) {
95+ auto& ten = item.get_ref<Tensor&>();
96+ if (ten.device().is_host()) {
97+ TensorDesc desc = ten.desc();
98+ desc.device = device_;
99+ Tensor tmp(desc);
100+ OUTCOME_TRY(stream_.Copy(ten.buffer(), tmp.buffer(), ten.byte_size()));
101+ ten = tmp;
102+ need_sync |= true;
103+ }
104+ } else if (item.is_array() || item.is_object()) {
105+ for (auto& child : item) {
106+ OUTCOME_TRY(Process(child, need_sync));
107+ }
108+ }
109+ return success();
41110 }
42- transform_ = creator->Create (cfg);
43- }
44111
45- Result<Value> TransformModule::operator ()(const Value& input) {
46- auto data = input;
47- OUTCOME_TRY (transform_->Apply (data));
48- return data;
49- }
112+ private:
113+ Device device_;
114+ Stream stream_;
115+ };
50116
51- MMDEPLOY_REGISTER_FACTORY_FUNC (Module, (Transform , 0 ), []( const Value& config) {
52- return CreateTask (TransformModule {config});
53- });
117+ MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (Preload , 0),
118+ [](const Value& config) { return CreateTask(Preload {config}); });
119+ # endif
54120
55121} // namespace mmdeploy
0 commit comments