Skip to content

Commit dabd996

Browse files
Krishn1412scotts
authored andcommitted
Replacing throw with TORCH_CHECK (#725)
Co-authored-by: Scott Schneider <[email protected]>
1 parent b76fe4a commit dabd996

File tree

4 files changed

+88
-94
lines changed

4 files changed

+88
-94
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ bool CpuDeviceInterface::DecodedFrameContext::operator!=(
3737
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
3838
: DeviceInterface(device) {
3939
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
40-
if (device_.type() != torch::kCPU) {
41-
throw std::runtime_error("Unsupported device: " + device_.str());
42-
}
40+
TORCH_CHECK(
41+
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
4342
}
4443

4544
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
@@ -161,9 +160,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
161160
frameOutput.data = outputTensor;
162161
}
163162
} else {
164-
throw std::runtime_error(
165-
"Invalid color conversion library: " +
166-
std::to_string(static_cast<int>(colorConversionLibrary)));
163+
TORCH_CHECK(
164+
false,
165+
"Invalid color conversion library: ",
166+
static_cast<int>(colorConversionLibrary));
167167
}
168168
}
169169

@@ -189,9 +189,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
189189
const UniqueAVFrame& avFrame) {
190190
int status = av_buffersrc_write_frame(
191191
filterGraphContext_.sourceContext, avFrame.get());
192-
if (status < AVSUCCESS) {
193-
throw std::runtime_error("Failed to add frame to buffer source context");
194-
}
192+
TORCH_CHECK(
193+
status >= AVSUCCESS, "Failed to add frame to buffer source context");
195194

196195
UniqueAVFrame filteredAVFrame(av_frame_alloc());
197196
status = av_buffersink_get_frame(
@@ -241,11 +240,12 @@ void CpuDeviceInterface::createFilterGraph(
241240
filterArgs.str().c_str(),
242241
nullptr,
243242
filterGraphContext_.filterGraph.get());
244-
if (status < 0) {
245-
throw std::runtime_error(
246-
std::string("Failed to create filter graph: ") + filterArgs.str() +
247-
": " + getFFMPEGErrorStringFromErrorCode(status));
248-
}
243+
TORCH_CHECK(
244+
status >= 0,
245+
"Failed to create filter graph: ",
246+
filterArgs.str(),
247+
": ",
248+
getFFMPEGErrorStringFromErrorCode(status));
249249

250250
status = avfilter_graph_create_filter(
251251
&filterGraphContext_.sinkContext,
@@ -254,11 +254,10 @@ void CpuDeviceInterface::createFilterGraph(
254254
nullptr,
255255
nullptr,
256256
filterGraphContext_.filterGraph.get());
257-
if (status < 0) {
258-
throw std::runtime_error(
259-
"Failed to create filter graph: " +
260-
getFFMPEGErrorStringFromErrorCode(status));
261-
}
257+
TORCH_CHECK(
258+
status >= 0,
259+
"Failed to create filter graph: ",
260+
getFFMPEGErrorStringFromErrorCode(status));
262261

263262
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
264263

@@ -268,11 +267,10 @@ void CpuDeviceInterface::createFilterGraph(
268267
pix_fmts,
269268
AV_PIX_FMT_NONE,
270269
AV_OPT_SEARCH_CHILDREN);
271-
if (status < 0) {
272-
throw std::runtime_error(
273-
"Failed to set output pixel formats: " +
274-
getFFMPEGErrorStringFromErrorCode(status));
275-
}
270+
TORCH_CHECK(
271+
status >= 0,
272+
"Failed to set output pixel formats: ",
273+
getFFMPEGErrorStringFromErrorCode(status));
276274

277275
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
278276
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
@@ -301,19 +299,17 @@ void CpuDeviceInterface::createFilterGraph(
301299
nullptr);
302300
outputs.reset(outputsTmp);
303301
inputs.reset(inputsTmp);
304-
if (status < 0) {
305-
throw std::runtime_error(
306-
"Failed to parse filter description: " +
307-
getFFMPEGErrorStringFromErrorCode(status));
308-
}
302+
TORCH_CHECK(
303+
status >= 0,
304+
"Failed to parse filter description: ",
305+
getFFMPEGErrorStringFromErrorCode(status));
309306

310307
status =
311308
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
312-
if (status < 0) {
313-
throw std::runtime_error(
314-
"Failed to configure filter graph: " +
315-
getFFMPEGErrorStringFromErrorCode(status));
316-
}
309+
TORCH_CHECK(
310+
status >= 0,
311+
"Failed to configure filter graph: ",
312+
getFFMPEGErrorStringFromErrorCode(status));
317313
}
318314

319315
void CpuDeviceInterface::createSwsContext(

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,8 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
166166
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
167167
: DeviceInterface(device) {
168168
TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
169-
if (device_.type() != torch::kCUDA) {
170-
throw std::runtime_error("Unsupported device: " + device_.str());
171-
}
169+
TORCH_CHECK(
170+
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
172171
}
173172

174173
CudaDeviceInterface::~CudaDeviceInterface() {

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,10 @@ void SingleStreamDecoder::initializeDecoder() {
103103
// which decodes a few frames to get missing info. For more, see:
104104
// https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html
105105
int status = avformat_find_stream_info(formatContext_.get(), nullptr);
106-
if (status < 0) {
107-
throw std::runtime_error(
108-
"Failed to find stream info: " +
109-
getFFMPEGErrorStringFromErrorCode(status));
110-
}
106+
TORCH_CHECK(
107+
status >= 0,
108+
"Failed to find stream info: ",
109+
getFFMPEGErrorStringFromErrorCode(status));
111110

112111
for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
113112
AVStream* avStream = formatContext_->streams[i];
@@ -222,11 +221,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
222221
break;
223222
}
224223

225-
if (status != AVSUCCESS) {
226-
throw std::runtime_error(
227-
"Failed to read frame from input file: " +
228-
getFFMPEGErrorStringFromErrorCode(status));
229-
}
224+
TORCH_CHECK(
225+
status == AVSUCCESS,
226+
"Failed to read frame from input file: ",
227+
getFFMPEGErrorStringFromErrorCode(status));
230228

231229
if (packet->flags & AV_PKT_FLAG_DISCARD) {
232230
continue;
@@ -279,11 +277,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
279277

280278
// Reset the seek-cursor back to the beginning.
281279
int status = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0);
282-
if (status < 0) {
283-
throw std::runtime_error(
284-
"Could not seek file to pts=0: " +
285-
getFFMPEGErrorStringFromErrorCode(status));
286-
}
280+
TORCH_CHECK(
281+
status >= 0,
282+
"Could not seek file to pts=0: ",
283+
getFFMPEGErrorStringFromErrorCode(status));
287284

288285
// Sort all frames by their pts.
289286
for (auto& [streamIndex, streamInfo] : streamInfos_) {
@@ -415,9 +412,7 @@ void SingleStreamDecoder::addStream(
415412
}
416413

417414
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
418-
if (retVal < AVSUCCESS) {
419-
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
420-
}
415+
TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
421416

422417
codecContext->time_base = streamInfo.stream->time_base;
423418
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
@@ -446,11 +441,11 @@ void SingleStreamDecoder::addVideoStream(
446441
auto& streamMetadata =
447442
containerMetadata_.allStreamMetadata[activeStreamIndex_];
448443

449-
if (seekMode_ == SeekMode::approximate &&
450-
!streamMetadata.averageFpsFromHeader.has_value()) {
451-
throw std::runtime_error(
452-
"Seek mode is approximate, but stream " +
453-
std::to_string(activeStreamIndex_) +
444+
if (seekMode_ == SeekMode::approximate) {
445+
TORCH_CHECK(
446+
streamMetadata.averageFpsFromHeader.has_value(),
447+
"Seek mode is approximate, but stream ",
448+
std::to_string(activeStreamIndex_),
454449
" does not have an average fps in its metadata.");
455450
}
456451

@@ -1048,11 +1043,13 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
10481043
desiredPts,
10491044
desiredPts,
10501045
0);
1051-
if (status < 0) {
1052-
throw std::runtime_error(
1053-
"Could not seek file to pts=" + std::to_string(desiredPts) + ": " +
1054-
getFFMPEGErrorStringFromErrorCode(status));
1055-
}
1046+
TORCH_CHECK(
1047+
status >= 0,
1048+
"Could not seek file to pts=",
1049+
std::to_string(desiredPts),
1050+
": ",
1051+
getFFMPEGErrorStringFromErrorCode(status));
1052+
10561053
decodeStats_.numFlushes++;
10571054
avcodec_flush_buffers(streamInfo.codecContext.get());
10581055
}
@@ -1121,21 +1118,20 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11211118
status = avcodec_send_packet(
11221119
streamInfo.codecContext.get(),
11231120
/*avpkt=*/nullptr);
1124-
if (status < AVSUCCESS) {
1125-
throw std::runtime_error(
1126-
"Could not flush decoder: " +
1127-
getFFMPEGErrorStringFromErrorCode(status));
1128-
}
1121+
TORCH_CHECK(
1122+
status >= AVSUCCESS,
1123+
"Could not flush decoder: ",
1124+
getFFMPEGErrorStringFromErrorCode(status));
11291125

11301126
reachedEOF = true;
11311127
break;
11321128
}
11331129

1134-
if (status < AVSUCCESS) {
1135-
throw std::runtime_error(
1136-
"Could not read frame from input file: " +
1137-
getFFMPEGErrorStringFromErrorCode(status));
1138-
}
1130+
TORCH_CHECK(
1131+
status >= AVSUCCESS,
1132+
"Could not read frame from input file: ",
1133+
getFFMPEGErrorStringFromErrorCode(status));
1134+
11391135
} while (packet->stream_index != activeStreamIndex_);
11401136

11411137
if (reachedEOF) {
@@ -1147,11 +1143,10 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11471143
// We got a valid packet. Send it to the decoder, and we'll receive it in
11481144
// the next iteration.
11491145
status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
1150-
if (status < AVSUCCESS) {
1151-
throw std::runtime_error(
1152-
"Could not push packet to decoder: " +
1153-
getFFMPEGErrorStringFromErrorCode(status));
1154-
}
1146+
TORCH_CHECK(
1147+
status >= AVSUCCESS,
1148+
"Could not push packet to decoder: ",
1149+
getFFMPEGErrorStringFromErrorCode(status));
11551150

11561151
decodeStats_.numPacketsSentToDecoder++;
11571152
}
@@ -1162,8 +1157,9 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11621157
"Requested next frame while there are no more frames left to "
11631158
"decode.");
11641159
}
1165-
throw std::runtime_error(
1166-
"Could not receive frame from decoder: " +
1160+
TORCH_CHECK(
1161+
false,
1162+
"Could not receive frame from decoder: ",
11671163
getFFMPEGErrorStringFromErrorCode(status));
11681164
}
11691165

@@ -1429,7 +1425,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
14291425
return std::floor(seconds * streamMetadata.averageFpsFromHeader.value());
14301426
}
14311427
default:
1432-
throw std::runtime_error("Unknown SeekMode");
1428+
TORCH_CHECK(false, "Unknown SeekMode");
14331429
}
14341430
}
14351431

@@ -1456,7 +1452,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
14561452
return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value());
14571453
}
14581454
default:
1459-
throw std::runtime_error("Unknown SeekMode");
1455+
TORCH_CHECK(false, "Unknown SeekMode");
14601456
}
14611457
}
14621458

@@ -1476,7 +1472,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14761472
streamInfo.timeBase);
14771473
}
14781474
default:
1479-
throw std::runtime_error("Unknown SeekMode");
1475+
TORCH_CHECK(false, "Unknown SeekMode");
14801476
}
14811477
}
14821478

@@ -1493,7 +1489,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
14931489
return streamMetadata.numFramesFromHeader;
14941490
}
14951491
default:
1496-
throw std::runtime_error("Unknown SeekMode");
1492+
TORCH_CHECK(false, "Unknown SeekMode");
14971493
}
14981494
}
14991495

@@ -1505,7 +1501,7 @@ double SingleStreamDecoder::getMinSeconds(
15051501
case SeekMode::approximate:
15061502
return 0;
15071503
default:
1508-
throw std::runtime_error("Unknown SeekMode");
1504+
TORCH_CHECK(false, "Unknown SeekMode");
15091505
}
15101506
}
15111507

@@ -1518,7 +1514,7 @@ std::optional<double> SingleStreamDecoder::getMaxSeconds(
15181514
return streamMetadata.durationSecondsFromHeader;
15191515
}
15201516
default:
1521-
throw std::runtime_error("Unknown SeekMode");
1517+
TORCH_CHECK(false, "Unknown SeekMode");
15221518
}
15231519
}
15241520

@@ -1552,10 +1548,10 @@ void SingleStreamDecoder::validateActiveStream(
15521548
}
15531549

15541550
void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
1555-
if (!scannedAllStreams_) {
1556-
throw std::runtime_error(
1557-
"Must scan all streams to update metadata before calling " + msg);
1558-
}
1551+
TORCH_CHECK(
1552+
scannedAllStreams_,
1553+
"Must scan all streams to update metadata before calling ",
1554+
msg);
15591555
}
15601556

15611557
void SingleStreamDecoder::validateFrameIndex(

src/torchcodec/_core/custom_ops.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,10 @@ void _add_video_stream(
243243
videoStreamOptions.colorConversionLibrary =
244244
ColorConversionLibrary::SWSCALE;
245245
} else {
246-
throw std::runtime_error(
247-
"Invalid color_conversion_library=" + stdColorConversionLibrary +
246+
TORCH_CHECK(
247+
false,
248+
"Invalid color_conversion_library=",
249+
stdColorConversionLibrary,
248250
". color_conversion_library must be either filtergraph or swscale.");
249251
}
250252
}
@@ -561,6 +563,7 @@ std::string get_stream_json_metadata(
561563
throw std::out_of_range(
562564
"stream_index out of bounds: " + std::to_string(stream_index));
563565
}
566+
564567
auto streamMetadata = allStreamMetadata[stream_index];
565568

566569
std::map<std::string, std::string> map;

0 commit comments

Comments
 (0)