Skip to content

Commit 8a92ab3

Browse files
committed
Pass over decoder impl
1 parent 315811b commit 8a92ab3

File tree

3 files changed

+48
-47
lines changed

3 files changed

+48
-47
lines changed

src/torchcodec/_core/CustomNvdecDeviceInterface.cpp

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,17 @@ static bool g_cuda_custom_nvdec = registerDeviceInterface(
3434
// NVDEC callback functions
3535
static int CUDAAPI
3636
HandleVideoSequence(void* pUserData, CUVIDEOFORMAT* pVideoFormat) {
37-
// printf("Static HandleVideoSequence called\n");
37+
printf(" IN CNI::handleVideoSequence\n");
38+
fflush(stdout);
3839
CustomNvdecDeviceInterface* decoder =
3940
static_cast<CustomNvdecDeviceInterface*>(pUserData);
4041
return decoder->handleVideoSequence(pVideoFormat);
4142
}
4243

4344
static int CUDAAPI
4445
HandlePictureDecode(void* pUserData, CUVIDPICPARAMS* pPicParams) {
45-
// printf("Static HandlePictureDecode called\n");
46+
printf(" IN CNI::handlePictureDecode\n");
47+
fflush(stdout);
4648
CustomNvdecDeviceInterface* decoder =
4749
static_cast<CustomNvdecDeviceInterface*>(pUserData);
4850
return decoder->handlePictureDecode(pPicParams);
@@ -61,7 +63,7 @@ HandlePictureDisplay(void* pUserData, CUVIDPARSERDISPINFO* pDispInfo) {
6163
CustomNvdecDeviceInterface::CustomNvdecDeviceInterface(
6264
const torch::Device& device)
6365
: DeviceInterface(device) {
64-
printf("IN CustomNvdecDeviceInterface::CustomNvdecDeviceInterface\n");
66+
printf(" IN CNI::CustomNvdecDeviceInterface\n");
6567
fflush(stdout);
6668
TORCH_CHECK(
6769
g_cuda_custom_nvdec, "CustomNvdecDeviceInterface was not registered!");
@@ -70,6 +72,8 @@ CustomNvdecDeviceInterface::CustomNvdecDeviceInterface(
7072
}
7173

7274
CustomNvdecDeviceInterface::~CustomNvdecDeviceInterface() {
75+
printf(" IN CNI::destructor\n");
76+
fflush(stdout);
7377
// Clean up any remaining frames in the queue
7478
{
7579
std::lock_guard<std::mutex> lock(frameQueueMutex_);
@@ -96,12 +100,15 @@ CustomNvdecDeviceInterface::~CustomNvdecDeviceInterface() {
96100
videoParser_ = nullptr;
97101
}
98102

99-
isInitialized_ = false;
100-
parserInitialized_ = false;
103+
parserCreated_ = false;
101104
}
102105

103106
std::optional<const AVCodec*> CustomNvdecDeviceInterface::findCodec(
104107
const AVCodecID& codecId) {
108+
109+
// TODONVDEC: revisit findCodec - unclear what it should do/return now that FFmpeg codec selection is bypassed.
110+
printf(" IN CNI::findCodec\n");
111+
fflush(stdout);
105112
// For custom NVDEC, we bypass FFmpeg codec selection entirely
106113
// We'll handle the codec selection in our own NVDEC initialization
107114
(void)codecId; // Suppress unused parameter warning
@@ -110,29 +117,16 @@ std::optional<const AVCodec*> CustomNvdecDeviceInterface::findCodec(
110117

111118
void CustomNvdecDeviceInterface::initializeContext(
112119
AVCodecContext* codecContext) {
120+
printf(" IN CNI::initializeContext\n");
121+
fflush(stdout);
113122
// Don't set hw_device_ctx - we handle decoding directly with NVDEC SDK
114123
// Just ensure CUDA context exists for PyTorch tensors
115124
torch::Tensor dummyTensor = torch::empty(
116125
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
117126

118-
// Initialize our custom NVDEC decoder
119-
initializeNvdecDecoder(codecContext->codec_id);
120-
121-
// Initialize video parser with the codec ID and extradata
122-
initializeVideoParser(codecContext->codec_id, codecContext->extradata, codecContext->extradata_size);
123-
}
124-
125-
void CustomNvdecDeviceInterface::initializeNvdecDecoder(AVCodecID codecId) {
126-
if (isInitialized_) {
127-
return; // Already initialized
128-
}
129-
130-
// Store the codec ID for later use
131-
currentCodecId_ = codecId;
132-
133-
// Convert AVCodecID to NVDEC codec type
127+
// Convert FFmpeg codec ID to NVDEC codec enum
134128
cudaVideoCodec nvCodec;
135-
switch (codecId) {
129+
switch (codecContext->codec_id) {
136130
case AV_CODEC_ID_H264:
137131
nvCodec = cudaVideoCodec_H264;
138132
break;
@@ -152,11 +146,10 @@ void CustomNvdecDeviceInterface::initializeNvdecDecoder(AVCodecID codecId) {
152146
TORCH_CHECK(
153147
false,
154148
"Unsupported codec for custom NVDEC: ",
155-
avcodec_get_name(codecId));
149+
avcodec_get_name(codecContext->codec_id));
156150
}
157151

158-
// Initialize video format structure (decoder will be created in
159-
// handleVideoSequence)
152+
// TODONVDEC figure out why this is needed and where videoFormat_ is actually used.
160153
memset(&videoFormat_, 0, sizeof(videoFormat_));
161154
videoFormat_.codec = nvCodec;
162155
videoFormat_.coded_width = 0; // Will be set when we get the first frame
@@ -165,15 +158,17 @@ void CustomNvdecDeviceInterface::initializeNvdecDecoder(AVCodecID codecId) {
165158
videoFormat_.bit_depth_luma_minus8 = 0;
166159
videoFormat_.bit_depth_chroma_minus8 = 0;
167160

168-
isInitialized_ = true;
161+
createVideoParser();
169162
}
170163

171-
void CustomNvdecDeviceInterface::initializeVideoParser(AVCodecID codecId, uint8_t* extradata, int extradata_size) {
172-
if (parserInitialized_) {
164+
165+
void CustomNvdecDeviceInterface::createVideoParser() {
166+
printf(" IN CNI::createVideoParser\n");
167+
fflush(stdout);
168+
if (parserCreated_) {
169+
// TODONVDEC: is this already-created early-return guard actually needed?
173170
return;
174171
}
175-
176-
// printf("Initializing NVDEC video parser for codec\n");
177172

178173
// Set up video parser parameters
179174
CUVIDPARSERPARAMS parserParams = {};
@@ -186,21 +181,19 @@ void CustomNvdecDeviceInterface::initializeVideoParser(AVCodecID codecId, uint8_
186181
parserParams.pfnSequenceCallback = HandleVideoSequence;
187182
parserParams.pfnDecodePicture = HandlePictureDecode;
188183
parserParams.pfnDisplayPicture = HandlePictureDisplay;
189-
190-
// printf("Parser params: pUserData=%p, pfnSequenceCallback=%p, pfnDecodePicture=%p, pfnDisplayPicture=%p\n",
191-
// parserParams.pUserData, (void*)parserParams.pfnSequenceCallback,
192-
// (void*)parserParams.pfnDecodePicture, (void*)parserParams.pfnDisplayPicture);
193184

194185
CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams);
195186
TORCH_CHECK(
196187
result == CUDA_SUCCESS, "Failed to create video parser: ", result);
197188

198-
parserInitialized_ = true;
189+
parserCreated_ = true;
199190
}
200191

201192
int CustomNvdecDeviceInterface::handleVideoSequence(
202193
CUVIDEOFORMAT* pVideoFormat) {
203-
// printf("In CustomNvdecDeviceInterface::handleVideoSequence\n");
194+
195+
printf(" IN CNI::handleVideoSequence\n");
196+
fflush(stdout);
204197
TORCH_CHECK(pVideoFormat != nullptr, "Invalid video format");
205198

206199
// Store video format
@@ -250,6 +243,8 @@ int CustomNvdecDeviceInterface::handlePictureDecode(
250243

251244
int CustomNvdecDeviceInterface::handlePictureDisplay(
252245
CUVIDPARSERDISPINFO* pDispInfo) {
246+
printf(" IN CNI::handlePictureDisplay\n");
247+
fflush(stdout);
253248
TORCH_CHECK(pDispInfo != nullptr, "Invalid display info");
254249

255250
// Queue the frame for later retrieval
@@ -279,8 +274,7 @@ int CustomNvdecDeviceInterface::handlePictureDisplay(
279274

280275
UniqueAVFrame CustomNvdecDeviceInterface::decodePacketDirectly(
281276
ReferenceAVPacket& packet) {
282-
TORCH_CHECK(isInitialized_, "NVDEC decoder not initialized");
283-
printf("IN CustomNvdecDeviceInterface::decodePacketDirectly\n");
277+
printf(" IN CNI::decodePacketDirectly\n");
284278
fflush(stdout);
285279

286280
// Extract compressed data from AVPacket
@@ -291,7 +285,7 @@ UniqueAVFrame CustomNvdecDeviceInterface::decodePacketDirectly(
291285
TORCH_CHECK(compressedData != nullptr && size > 0, "Invalid packet data");
292286

293287
// Video parser should already have been created by initializeContext
294-
TORCH_CHECK(parserInitialized_, "Video parser not initialized");
288+
TORCH_CHECK(parserCreated_, "Video parser not initialized");
295289

296290
// Parse the packet data (now already in Annex B format from bitstream filter)
297291
// printf("About to parse packet: size=%d, pts=%lld\n", size, pts);
@@ -313,6 +307,8 @@ UniqueAVFrame CustomNvdecDeviceInterface::decodePacketDirectly(
313307
std::lock_guard<std::mutex> lock(frameQueueMutex_);
314308
if (frameQueue_.empty()) {
315309
// No frame ready yet (async decoding)
310+
printf(" No frame ready after parsing\n");
311+
fflush(stdout);
316312
return UniqueAVFrame(nullptr);
317313
}
318314

@@ -379,7 +375,7 @@ void CustomNvdecDeviceInterface::convertAVFrameToFrameOutput(
379375
FrameOutput& frameOutput,
380376
std::optional<torch::Tensor> preAllocatedOutputTensor) {
381377

382-
printf("In CNI convertAVFrameToFrameOutput\n");
378+
printf(" In CNI convertAVFrameToFrameOutput\n");
383379
fflush(stdout);
384380

385381
TORCH_CHECK(

src/torchcodec/_core/CustomNvdecDeviceInterface.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ class CustomNvdecDeviceInterface : public DeviceInterface {
6565

6666
// Video format info
6767
CUVIDEOFORMAT videoFormat_;
68-
AVCodecID currentCodecId_ = AV_CODEC_ID_NONE;
69-
bool isInitialized_ = false;
70-
bool parserInitialized_ = false;
68+
bool parserCreated_ = false;
7169

7270
// Frame queue for async decoding - stores frame pointer, pitch, and display info
7371
struct FrameData {
@@ -78,11 +76,9 @@ class CustomNvdecDeviceInterface : public DeviceInterface {
7876
std::queue<FrameData> frameQueue_;
7977
std::mutex frameQueueMutex_;
8078

81-
// Custom context initialization for direct NVDEC usage
82-
void initializeNvdecDecoder(AVCodecID codecId);
8379

8480
// Create the NVDEC video parser (no-op if it has already been created)
85-
void initializeVideoParser(AVCodecID codecId, uint8_t* extradata, int extradata_size);
81+
void createVideoParser();
8682

8783
// Convert CUDA frame pointer to AVFrame
8884
UniqueAVFrame convertCudaFrameToAVFrame(

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,9 @@ void SingleStreamDecoder::addStream(
455455
// TODO_CODE_QUALITY same as above.
456456
if (mediaType == AVMEDIA_TYPE_VIDEO) {
457457
if (deviceInterface_) {
458+
// TODONVDEC consider changing the name of this, it's not just about
459+
// AVCodecContext initialization anymore, it's more generally about
460+
// initializing the device interface
458461
deviceInterface_->initializeContext(codecContext);
459462
}
460463
}
@@ -1222,6 +1225,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12221225
ReferenceAVPacket filteredPacket(filteredAutoPacket);
12231226

12241227
// Apply bitstream filtering if needed
1228+
// TODONVDEC see other todos above about BSF logic needing to be more robust.
12251229
if (streamInfo.bitstreamFilter != nullptr) {
12261230
int retVal =
12271231
av_bsf_send_packet(streamInfo.bitstreamFilter.get(), packet.get());
@@ -1241,15 +1245,20 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12411245
}
12421246

12431247
// Use custom packet decoding (e.g., direct NVDEC)
1248+
printf("SingleStreamDecoder calling CNI::decodePacketDirectly\n");
1249+
fflush(stdout);
12441250
UniqueAVFrame decodedFrame =
12451251
deviceInterface_->decodePacketDirectly(*packetToSend);
1246-
decodeStats_.numPacketsSentToDecoder++;
12471252

12481253
if (decodedFrame && filterFunction(decodedFrame)) {
12491254
// We got the frame we're looking for from direct decoding
1255+
printf("SingleStreamDecoder: we've got the frame!\n");
1256+
fflush(stdout);
12501257
avFrame = std::move(decodedFrame);
12511258
break;
12521259
}
1260+
printf("SingleStreamDecoder: that's not the frame, continuing loop\n");
1261+
fflush(stdout);
12531262
// If custom decoding didn't produce the desired frame, continue the loop
12541263
}
12551264
} else {

0 commit comments

Comments
 (0)