Skip to content

Commit 89ada0a

Browse files
[video] make NVDec cache size adjustable (meta-pytorch#1246)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
1 parent 900979e commit 89ada0a

File tree

12 files changed

+311
-22
lines changed

12 files changed

+311
-22
lines changed

docs/source/api_ref_decoders.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ For an audio decoder tutorial, see: :ref:`sphx_glr_generated_examples_decoding_a
2525
:template: function.rst
2626

2727
set_cuda_backend
28+
set_nvdec_cache_capacity
29+
get_nvdec_cache_capacity
2830

2931
.. autosummary::
3032
:toctree: generated/

src/torchcodec/_core/CMakeLists.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ function(make_torchcodec_libraries
137137
Transform.cpp
138138
Metadata.cpp
139139
SwScale.cpp
140+
NVDECCacheConfig.cpp
140141
)
141142

142143
if(ENABLE_CUDA)
@@ -163,9 +164,10 @@ function(make_torchcodec_libraries
163164
)
164165

165166
if(ENABLE_CUDA)
166-
# We have to define USE_CUDA because we rely on some APIs like
167-
# aoti_torch_get_current_cuda_stream, which are only exposed in torch
168-
# headers if is defined!
167+
# We define USE_CUDA to guard CUDA-specific code paths (e.g.
168+
# NVDECCache usage in NVDECCacheConfig.cpp) and because some torch
169+
# APIs like aoti_torch_get_current_cuda_stream are only exposed when
170+
# USE_CUDA is defined.
169171
# https://github.com/pytorch/pytorch/blob/98e36864e640023a716e058d894ea2d20e76e5f7/torch/csrc/inductor/aoti_torch/c/shim.h#L573-L602
170172
target_compile_definitions(${core_library_name} PRIVATE USE_CUDA)
171173
endif()

src/torchcodec/_core/NVDECCache.cpp

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "CUDACommon.h"
1010
#include "FFMPEGCommon.h"
1111
#include "NVDECCache.h"
12+
#include "NVDECCacheConfig.h"
1213

1314
#include <cuda_runtime.h> // For cudaGetDevice
1415

@@ -19,9 +20,13 @@ extern "C" {
1920

2021
namespace facebook::torchcodec {
2122

22-
NVDECCache& NVDECCache::getCache(const StableDevice& device) {
23+
NVDECCache* NVDECCache::getCacheInstances() {
2324
static NVDECCache cacheInstances[MAX_CUDA_GPUS];
24-
return cacheInstances[getDeviceIndex(device)];
25+
return cacheInstances;
26+
}
27+
28+
NVDECCache& NVDECCache::getCache(const StableDevice& device) {
29+
return getCacheInstances()[getDeviceIndex(device)];
2530
}
2631

2732
UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) {
@@ -39,6 +44,21 @@ UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) {
3944
return nullptr;
4045
}
4146

47+
// Evicts the least-recently-used entry from cache_.
48+
// Caller must hold cacheLock_ when calling this function.
49+
void NVDECCache::evictLRUEntry() {
50+
if (cache_.empty()) {
51+
return;
52+
}
53+
auto victim = cache_.begin();
54+
for (auto it = cache_.begin(); it != cache_.end(); ++it) {
55+
if (it->second.lastUsed < victim->second.lastUsed) {
56+
victim = it;
57+
}
58+
}
59+
cache_.erase(victim);
60+
}
61+
4262
void NVDECCache::returnDecoder(
4363
CUVIDEOFORMAT* videoFormat,
4464
UniqueCUvideodecoder decoder) {
@@ -47,25 +67,40 @@ void NVDECCache::returnDecoder(
4767
CacheKey key(videoFormat);
4868
std::lock_guard<std::mutex> lock(cacheLock_);
4969

50-
// Evict least recently used entry if at capacity.
51-
// This search is O(MAX_CACHE_SIZE) but MAX_CACHE_SIZE is always small, so
52-
// this isn't significant.
53-
if (cache_.size() >= MAX_CACHE_SIZE) {
54-
auto victim = cache_.begin();
55-
for (auto it = cache_.begin(); it != cache_.end(); ++it) {
56-
if (it->second.lastUsed < victim->second.lastUsed) {
57-
victim = it;
58-
}
59-
}
60-
cache_.erase(victim);
70+
int capacity = getNVDECCacheCapacity();
71+
if (capacity <= 0) {
72+
return;
73+
}
74+
75+
// Evict least recently used entries until under capacity.
76+
// This search is O(capacity); capacity is expected to be small,
77+
// so the linear-vs-constant lookup overhead is negligible.
78+
while (cache_.size() >= static_cast<size_t>(capacity)) {
79+
evictLRUEntry();
6180
}
6281

6382
// Add the decoder back to cache
6483
cache_.emplace(key, CacheEntry(std::move(decoder), lastUsedCounter_++));
6584

6685
STD_TORCH_CHECK(
67-
cache_.size() <= MAX_CACHE_SIZE,
68-
"Cache size exceeded maximum limit, please report a bug");
86+
cache_.size() <= static_cast<size_t>(capacity),
87+
"Cache size exceeded capacity, please report a bug");
88+
}
89+
90+
void NVDECCache::evictExcessEntriesAcrossDevices(int capacity) {
91+
NVDECCache* instances = getCacheInstances();
92+
for (int i = 0; i < MAX_CUDA_GPUS; ++i) {
93+
std::lock_guard<std::mutex> lock(instances[i].cacheLock_);
94+
while (instances[i].cache_.size() > static_cast<size_t>(capacity)) {
95+
instances[i].evictLRUEntry();
96+
}
97+
}
98+
}
99+
100+
int NVDECCache::getCacheSizeForDevice(int device_index) {
101+
NVDECCache* instances = getCacheInstances();
102+
std::lock_guard<std::mutex> lock(instances[device_index].cacheLock_);
103+
return static_cast<int>(instances[device_index].cache_.size());
69104
}
70105

71106
} // namespace facebook::torchcodec

src/torchcodec/_core/NVDECCache.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <cuda.h>
1414

1515
#include "NVCUVIDRuntimeLoader.h"
16+
#include "NVDECCacheConfig.h"
1617
#include "StableABICompat.h"
1718
#include "nvcuvid_include/cuviddec.h"
1819
#include "nvcuvid_include/nvcuvid.h"
@@ -56,6 +57,13 @@ class NVDECCache {
5657
// Return decoder to cache using LRU eviction.
5758
void returnDecoder(CUVIDEOFORMAT* videoFormat, UniqueCUvideodecoder decoder);
5859

60+
// Iterates over all per-device cache instances and evicts LRU entries until each
61+
// cache's size is at most capacity. Called from setNVDECCacheCapacity().
62+
static void evictExcessEntriesAcrossDevices(int capacity);
63+
64+
// Returns the number of entries in the cache for a given device index.
65+
static int getCacheSizeForDevice(int device_index);
66+
5967
private:
6068
// Cache key struct: a decoder can be reused and taken from the cache only if
6169
// all these parameters match.
@@ -103,12 +111,13 @@ class NVDECCache {
103111
NVDECCache() = default;
104112
~NVDECCache() = default;
105113

114+
void evictLRUEntry();
115+
116+
static NVDECCache* getCacheInstances();
117+
106118
std::multimap<CacheKey, CacheEntry> cache_;
107119
std::mutex cacheLock_;
108120
uint64_t lastUsedCounter_ = 0;
109-
110-
// Max number of cached decoders, per device
111-
static constexpr int MAX_CACHE_SIZE = 20;
112121
};
113122

114123
} // namespace facebook::torchcodec
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "NVDECCacheConfig.h"
8+
9+
#include <atomic>
10+
#include <mutex>
11+
12+
#include "c10/util/Exception.h"
13+
14+
#ifdef USE_CUDA
15+
#include "CUDACommon.h"
16+
#include "NVDECCache.h"
17+
#endif
18+
19+
namespace facebook::torchcodec {
20+
21+
static std::atomic<int> g_nvdecCacheCapacity{DEFAULT_NVDEC_CACHE_CAPACITY};
22+
// This mutex serializes setNVDECCacheCapacity() calls so that the atomic store
23+
// and the subsequent cache eviction happen as one unit. getNVDECCacheCapacity()
24+
// intentionally reads the atomic without this mutex: callers like
25+
// returnDecoder() may briefly see a stale value during an ongoing
26+
// setNVDECCacheCapacity(), which is acceptable because the worst case is a
27+
// single decoder being added back to the cache after eviction. That entry will
28+
// be consumed by a subsequent getDecoder() call or evicted by a future
29+
// returnDecoder() or setNVDECCacheCapacity() call.
30+
static std::mutex g_nvdecCacheCapacityMutex;
31+
32+
void setNVDECCacheCapacity(int capacity) {
33+
TORCH_CHECK(
34+
capacity >= 0,
35+
"NVDEC cache capacity must be non-negative, got ",
36+
capacity);
37+
std::lock_guard<std::mutex> lock(g_nvdecCacheCapacityMutex);
38+
g_nvdecCacheCapacity.store(capacity);
39+
#ifdef USE_CUDA
40+
NVDECCache::evictExcessEntriesAcrossDevices(capacity);
41+
#endif
42+
}
43+
44+
int getNVDECCacheCapacity() {
45+
return g_nvdecCacheCapacity.load();
46+
}
47+
48+
int getNVDECCacheSize([[maybe_unused]] int device_index) {
49+
#ifdef USE_CUDA
50+
TORCH_CHECK(
51+
device_index >= 0 && device_index < MAX_CUDA_GPUS,
52+
"device_index must be between 0 and ",
53+
MAX_CUDA_GPUS - 1,
54+
", got ",
55+
device_index);
56+
return NVDECCache::getCacheSizeForDevice(device_index);
57+
#else
58+
return 0;
59+
#endif
60+
}
61+
62+
} // namespace facebook::torchcodec
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
// This header is intentionally CUDA-free so it can be included from
10+
// custom_ops.cpp which is compiled without CUDA headers.
11+
12+
namespace facebook::torchcodec {
13+
14+
// Default capacity of the per-device NVDEC decoder cache.
15+
// capacity == maximum number of cached instances allowed.
16+
constexpr int DEFAULT_NVDEC_CACHE_CAPACITY = 20;
17+
18+
// Set the capacity of the per-device NVDEC decoder cache.
19+
// capacity must be non-negative.
20+
void setNVDECCacheCapacity(int capacity);
21+
22+
// Get the current capacity of the per-device NVDEC decoder cache.
23+
int getNVDECCacheCapacity();
24+
25+
// Get the current number of entries in the NVDEC decoder cache for a device.
26+
// This is currently only used for tests, and not publicly exposed.
27+
// TODO: consider exposing this in the public API.
28+
int getNVDECCacheSize(int device_index);
29+
30+
} // namespace facebook::torchcodec

src/torchcodec/_core/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_add_video_stream,
1818
_get_backend_details,
1919
_get_key_frame_indices,
20+
_get_nvdec_cache_size,
2021
_test_frame_pts_equality,
2122
add_audio_stream,
2223
add_video_stream,
@@ -42,6 +43,8 @@
4243
get_frames_in_range,
4344
get_json_metadata,
4445
get_next_frame,
46+
get_nvdec_cache_capacity,
4547
scan_all_streams_to_update_metadata,
4648
seek_to_pts,
49+
set_nvdec_cache_capacity,
4750
)

src/torchcodec/_core/custom_ops.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "AVIOFileLikeContext.h"
1313
#include "AVIOTensorContext.h"
1414
#include "Encoder.h"
15+
#include "NVDECCacheConfig.h"
1516
#include "SingleStreamDecoder.h"
1617
#include "StableABICompat.h"
1718
#include "ValidationUtils.h"
@@ -76,6 +77,9 @@ STABLE_TORCH_LIBRARY(torchcodec_ns, m) {
7677
m.def(
7778
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
7879
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
80+
m.def("set_nvdec_cache_capacity(int capacity) -> ()");
81+
m.def("get_nvdec_cache_capacity() -> int");
82+
m.def("_get_nvdec_cache_size(int device_index) -> int");
7983
}
8084

8185
namespace {
@@ -1085,6 +1089,28 @@ void scan_all_streams_to_update_metadata(torch::stable::Tensor& decoder) {
10851089
videoDecoder->scanFileAndUpdateMetadataAndIndex();
10861090
}
10871091

1092+
void set_nvdec_cache_capacity(int64_t capacity) {
1093+
int capacityInt = validateInt64ToInt(capacity, "capacity");
1094+
STD_TORCH_CHECK(
1095+
capacityInt >= 0,
1096+
"NVDEC cache capacity must be non-negative, got ",
1097+
capacityInt);
1098+
setNVDECCacheCapacity(capacityInt);
1099+
}
1100+
1101+
int64_t get_nvdec_cache_capacity() {
1102+
return static_cast<int64_t>(getNVDECCacheCapacity());
1103+
}
1104+
1105+
int64_t _get_nvdec_cache_size(int64_t device_index) {
1106+
int deviceIndexInt = validateInt64ToInt(device_index, "device_index");
1107+
STD_TORCH_CHECK(
1108+
deviceIndexInt >= 0,
1109+
"device_index must be non-negative, got ",
1110+
deviceIndexInt);
1111+
return static_cast<int64_t>(getNVDECCacheSize(deviceIndexInt));
1112+
}
1113+
10881114
STABLE_TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
10891115
m.impl("create_from_file", TORCH_BOX(&create_from_file));
10901116
m.impl("create_from_tensor", TORCH_BOX(&create_from_tensor));
@@ -1095,6 +1121,9 @@ STABLE_TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
10951121
m.impl("encode_video_to_file", TORCH_BOX(&encode_video_to_file));
10961122
m.impl("encode_video_to_tensor", TORCH_BOX(&encode_video_to_tensor));
10971123
m.impl("_encode_video_to_file_like", TORCH_BOX(&_encode_video_to_file_like));
1124+
m.impl("set_nvdec_cache_capacity", TORCH_BOX(&set_nvdec_cache_capacity));
1125+
m.impl("get_nvdec_cache_capacity", TORCH_BOX(&get_nvdec_cache_capacity));
1126+
m.impl("_get_nvdec_cache_size", TORCH_BOX(&_get_nvdec_cache_size));
10981127
}
10991128

11001129
STABLE_TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {

src/torchcodec/_core/ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ def add_video_stream(
136136
torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default
137137
)
138138
_get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default
139+
set_nvdec_cache_capacity = torch.ops.torchcodec_ns.set_nvdec_cache_capacity.default
140+
get_nvdec_cache_capacity = torch.ops.torchcodec_ns.get_nvdec_cache_capacity.default
141+
_get_nvdec_cache_size = torch.ops.torchcodec_ns._get_nvdec_cache_size.default
139142

140143

141144
# =============================
@@ -572,3 +575,18 @@ def get_ffmpeg_library_versions():
572575
@register_fake("torchcodec_ns::_get_backend_details")
573576
def _get_backend_details_abstract(decoder: torch.Tensor) -> str:
574577
return ""
578+
579+
580+
@register_fake("torchcodec_ns::set_nvdec_cache_capacity")
581+
def set_nvdec_cache_capacity_abstract(capacity: int) -> None:
582+
return
583+
584+
585+
@register_fake("torchcodec_ns::get_nvdec_cache_capacity")
586+
def get_nvdec_cache_capacity_abstract() -> int:
587+
return 0
588+
589+
590+
@register_fake("torchcodec_ns::_get_nvdec_cache_size")
591+
def _get_nvdec_cache_size_abstract(device_index: int) -> int:
592+
return 0

src/torchcodec/decoders/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
from .._core import AudioStreamMetadata, VideoStreamMetadata
88
from ._audio_decoder import AudioDecoder # noqa
9-
from ._decoder_utils import set_cuda_backend # noqa
9+
from ._decoder_utils import ( # noqa
10+
get_nvdec_cache_capacity,
11+
set_cuda_backend,
12+
set_nvdec_cache_capacity,
13+
)
1014
from ._video_decoder import CpuFallbackStatus, VideoDecoder # noqa
1115

1216
SimpleVideoDecoder = VideoDecoder

0 commit comments

Comments
 (0)