Skip to content

Commit 3378484

Browse files
committed
Improve static initialization
1 parent de7c3c0 commit 3378484

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ extern "C" {
1414
namespace facebook::torchcodec {
1515
namespace {
1616

17-
bool g_cpu = registerDeviceInterface(
17+
static bool g_cpu = registerDeviceInterface(
1818
torch::kCPU,
1919
[](const torch::Device& device) { return new CpuDeviceInterface(device); });
2020

@@ -36,6 +36,7 @@ bool CpuDeviceInterface::DecodedFrameContext::operator!=(
3636

3737
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
3838
: DeviceInterface(device) {
39+
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
3940
if (device_.type() != torch::kCPU) {
4041
throw std::runtime_error("Unsupported device: " + device_.str());
4142
}

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ extern "C" {
1515
namespace facebook::torchcodec {
1616
namespace {
1717

18-
bool g_cuda =
18+
static bool g_cuda =
1919
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
2020
return new CudaDeviceInterface(device);
2121
});
@@ -165,6 +165,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
165165

166166
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
167167
: DeviceInterface(device) {
168+
TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
168169
if (device_.type() != torch::kCUDA) {
169170
throw std::runtime_error("Unsupported device: " + device_.str());
170171
}

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ namespace facebook::torchcodec {
1212

1313
namespace {
1414
using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
15-
std::mutex g_interface_mutex;
16-
std::unique_ptr<DeviceInterfaceMap> g_interface_map;
15+
static std::mutex g_interface_mutex;
16+
17+
DeviceInterfaceMap& getDeviceMap() {
18+
static DeviceInterfaceMap deviceMap;
19+
return deviceMap;
20+
}
1721

1822
std::string getDeviceType(const std::string& device) {
1923
size_t pos = device.find(':');
@@ -29,35 +33,31 @@ bool registerDeviceInterface(
2933
torch::DeviceType deviceType,
3034
CreateDeviceInterfaceFn createInterface) {
3135
std::scoped_lock lock(g_interface_mutex);
32-
if (!g_interface_map) {
33-
// We delay this initialization until runtime to avoid the Static
34-
// Initialization Order Fiasco:
35-
//
36-
// https://en.cppreference.com/w/cpp/language/siof
37-
g_interface_map = std::make_unique<DeviceInterfaceMap>();
38-
}
36+
DeviceInterfaceMap& deviceMap = getDeviceMap();
37+
3938
TORCH_CHECK(
40-
g_interface_map->find(deviceType) == g_interface_map->end(),
39+
deviceMap.find(deviceType) == deviceMap.end(),
4140
"Device interface already registered for ",
4241
deviceType);
43-
g_interface_map->insert({deviceType, createInterface});
42+
deviceMap.insert({deviceType, createInterface});
43+
4444
return true;
4545
}
4646

4747
torch::Device createTorchDevice(const std::string device) {
4848
std::scoped_lock lock(g_interface_mutex);
4949
std::string deviceType = getDeviceType(device);
50+
DeviceInterfaceMap& deviceMap = getDeviceMap();
51+
5052
auto deviceInterface = std::find_if(
51-
g_interface_map->begin(),
52-
g_interface_map->end(),
53+
deviceMap.begin(),
54+
deviceMap.end(),
5355
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
5456
return device.rfind(
5557
torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
5658
});
5759
TORCH_CHECK(
58-
deviceInterface != g_interface_map->end(),
59-
"Unsupported device: ",
60-
device);
60+
deviceInterface != deviceMap.end(), "Unsupported device: ", device);
6161

6262
return torch::Device(device);
6363
}
@@ -66,13 +66,14 @@ std::unique_ptr<DeviceInterface> createDeviceInterface(
6666
const torch::Device& device) {
6767
auto deviceType = device.type();
6868
std::scoped_lock lock(g_interface_mutex);
69+
DeviceInterfaceMap& deviceMap = getDeviceMap();
70+
6971
TORCH_CHECK(
70-
g_interface_map->find(deviceType) != g_interface_map->end(),
72+
deviceMap.find(deviceType) != deviceMap.end(),
7173
"Unsupported device: ",
7274
device);
7375

74-
return std::unique_ptr<DeviceInterface>(
75-
(*g_interface_map)[deviceType](device));
76+
return std::unique_ptr<DeviceInterface>(deviceMap[deviceType](device));
7677
}
7778

7879
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)