99#include < torch/types.h>
1010#include < memory>
1111#include < mutex>
12+ #include " src/torchcodec/_core/CUDACommon.h"
1213
1314namespace facebook ::torchcodec {
1415
@@ -95,30 +96,11 @@ class PerGpuCache {
9596 std::vector<std::unique_ptr<Cache<T, D>>> cache_;
9697};
9798
98- // Note: this function is inline for convenience, not performance. Because the
99- // rest of this file is template functions, they must all be defined in this
100- // header. This function is not a template function, and should, in principle,
101- // be defined in a .cpp file to preserve the One Definition Rule. That's
102- // annoying for such a small amount of code, so we just inline it. If this file
103- // grows, and there are more such functions, we should break them out into a
104- // .cpp file.
105- inline torch::DeviceIndex getNonNegativeDeviceIndex (
106- const torch::Device& device) {
107- torch::DeviceIndex deviceIndex = device.index ();
108- // For single GPU machines libtorch returns -1 for the device index. So for
109- // that case we set the device index to 0. That's used in per-gpu cache
110- // implementation and during initialization of CUDA and FFmpeg contexts
111- // which require non negative indices.
112- deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0 );
113- TORCH_CHECK (deviceIndex >= 0 , " Device index out of range" );
114- return deviceIndex;
115- }
116-
11799template <typename T, typename D>
118100bool PerGpuCache<T, D>::addIfCacheHasCapacity(
119101 const torch::Device& device,
120102 element_type&& obj) {
121- torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device);
103+ int deviceIndex = getDeviceIndex (device);
122104 TORCH_CHECK (
123105 static_cast <size_t >(deviceIndex) < cache_.size (),
124106 " Device index out of range" );
@@ -128,7 +110,7 @@ bool PerGpuCache<T, D>::addIfCacheHasCapacity(
128110template <typename T, typename D>
129111typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
130112 const torch::Device& device) {
131- torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device);
113+ int deviceIndex = getDeviceIndex (device);
132114 TORCH_CHECK (
133115 static_cast <size_t >(deviceIndex) < cache_.size (),
134116 " Device index out of range" );
0 commit comments