@@ -783,42 +783,20 @@ class dev_mgr {
783783 list_devices();
784784#endif
785785 }
786- void check_id(unsigned int & id) const {
786+ void check_id(unsigned int id) const {
787787 std::lock_guard<std::recursive_mutex> lock(m_mutex);
788788 if (id >= _devs.size()) {
789- if (!_stack_addr.count(&_dev_stack)) {
790- id = DEFAULT_DEVICE_ID;
791- return;
792- }
793789 throw std::runtime_error("invalid device id");
794790 }
795791 }
796-
797- class stack_wrapper : public std::stack<unsigned int> {
798- public:
799- stack_wrapper() {
800- std::lock_guard<std::recursive_mutex> lock(instance().m_mutex);
801- _stack_addr.insert(this);
802- }
803- ~stack_wrapper() {
804- std::lock_guard<std::recursive_mutex> lock(instance().m_mutex);
805- _stack_addr.erase(this);
806- }
807- };
808-
809792 std::vector<std::shared_ptr<device_ext>> _devs;
810793 /// stack of devices resulting from CUDA context change;
811- inline static thread_local stack_wrapper _dev_stack;
794+ static inline thread_local std::stack<unsigned int> _dev_stack;
812795 /// DEFAULT_DEVICE_ID is used, if current_device_id() finds an empty
813796 /// _dev_stack, which means the default device should be used for the current
814797 /// thread.
815798 const unsigned int DEFAULT_DEVICE_ID = 0;
816799 int _cpu_device = -1;
817- // Add address when constructing _dev_stack, and remove it when destructing.
818- // It can be used to check if _dev_stack is destructed to avoid getting
819- // garbage data after _dev_stack is destroyed when destructing global static
820- // variables.
821- inline static std::set<stack_wrapper *> _stack_addr;
822800};
823801
824802/// Util function to get the default queue of current selected device depends on
0 commit comments