@@ -12,8 +12,12 @@ namespace facebook::torchcodec {
1212
1313namespace {
1414using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
15- std::mutex g_interface_mutex;
16- std::unique_ptr<DeviceInterfaceMap> g_interface_map;
15+ static std::mutex g_interface_mutex;
16+
17+ DeviceInterfaceMap& getDeviceMap () {
18+ static DeviceInterfaceMap deviceMap;
19+ return deviceMap;
20+ }
1721
1822std::string getDeviceType (const std::string& device) {
1923 size_t pos = device.find (' :' );
@@ -29,35 +33,31 @@ bool registerDeviceInterface(
2933 torch::DeviceType deviceType,
3034 CreateDeviceInterfaceFn createInterface) {
3135 std::scoped_lock lock (g_interface_mutex);
32- if (!g_interface_map) {
33- // We delay this initialization until runtime to avoid the Static
34- // Initialization Order Fiasco:
35- //
36- // https://en.cppreference.com/w/cpp/language/siof
37- g_interface_map = std::make_unique<DeviceInterfaceMap>();
38- }
36+ DeviceInterfaceMap& deviceMap = getDeviceMap ();
37+
3938 TORCH_CHECK (
40- g_interface_map-> find (deviceType) == g_interface_map-> end (),
39+ deviceMap. find (deviceType) == deviceMap. end (),
4140 " Device interface already registered for " ,
4241 deviceType);
43- g_interface_map->insert ({deviceType, createInterface});
42+ deviceMap.insert ({deviceType, createInterface});
43+
4444 return true ;
4545}
4646
4747torch::Device createTorchDevice (const std::string device) {
4848 std::scoped_lock lock (g_interface_mutex);
4949 std::string deviceType = getDeviceType (device);
50+ DeviceInterfaceMap& deviceMap = getDeviceMap ();
51+
5052 auto deviceInterface = std::find_if (
51- g_interface_map-> begin (),
52- g_interface_map-> end (),
53+ deviceMap. begin (),
54+ deviceMap. end (),
5355 [&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
5456 return device.rfind (
5557 torch::DeviceTypeName (arg.first , /* lcase*/ true ), 0 ) == 0 ;
5658 });
5759 TORCH_CHECK (
58- deviceInterface != g_interface_map->end (),
59- " Unsupported device: " ,
60- device);
60+ deviceInterface != deviceMap.end (), " Unsupported device: " , device);
6161
6262 return torch::Device (device);
6363}
@@ -66,13 +66,14 @@ std::unique_ptr<DeviceInterface> createDeviceInterface(
6666 const torch::Device& device) {
6767 auto deviceType = device.type ();
6868 std::scoped_lock lock (g_interface_mutex);
69+ DeviceInterfaceMap& deviceMap = getDeviceMap ();
70+
6971 TORCH_CHECK (
70- g_interface_map-> find (deviceType) != g_interface_map-> end (),
72+ deviceMap. find (deviceType) != deviceMap. end (),
7173 " Unsupported device: " ,
7274 device);
7375
74- return std::unique_ptr<DeviceInterface>(
75- (*g_interface_map)[deviceType](device));
76+ return std::unique_ptr<DeviceInterface>(deviceMap[deviceType](device));
7677}
7778
7879} // namespace facebook::torchcodec
0 commit comments