77#include < functional>
88#include < memory>
99#include < optional>
10+ #include < ostream>
1011#include < string>
1112#include < vector>
1213
1314#include " mmdeploy/core/macro.h"
1415#include " mmdeploy/core/mpl/type_traits.h"
1516#include " mmdeploy/core/status_code.h"
17+ #include " mmdeploy/core/utils/formatter.h"
1618
1719namespace mmdeploy {
1820
@@ -97,6 +99,11 @@ class Device {
9799 return PlatformId (platform_id_);
98100 }
99101
102+ friend std::ostream& operator <<(std::ostream& os, const Device& device) {
103+ os << " (" << device.platform_id_ << " , " << device.device_id_ << " )" ;
104+ return os;
105+ }
106+
100107 private:
101108 int platform_id_{0 };
102109 int device_id_{0 };
@@ -112,6 +119,9 @@ class MMDEPLOY_API Platform {
112119 // throws if not found
113120 explicit Platform (int platform_id);
114121
122+ // bind device with the current thread
123+ Result<void > Bind (Device device, Device* prev);
124+
115125 // -1 if invalid
116126 int GetPlatformId () const ;
117127
@@ -135,6 +145,27 @@ class MMDEPLOY_API Platform {
135145
136146MMDEPLOY_API const char * GetPlatformName (PlatformId id);
137147
148+ class DeviceGuard {
149+ public:
150+ explicit DeviceGuard (Device device) : platform_(device.platform_id()) {
151+ auto r = platform_.Bind (device, &prev_);
152+ if (!r) {
153+ MMDEPLOY_ERROR (" failed to bind device {}: {}" , device, r.error ().message ().c_str ());
154+ }
155+ }
156+
157+ ~DeviceGuard () {
158+ auto r = platform_.Bind (prev_, nullptr );
159+ if (!r) {
160+ MMDEPLOY_ERROR (" failed to unbind device {}: {}" , prev_, r.error ().message ().c_str ());
161+ }
162+ }
163+
164+ private:
165+ Platform platform_;
166+ Device prev_;
167+ };
168+
138169class MMDEPLOY_API Stream {
139170 public:
140171 Stream () = default ;
0 commit comments