Skip to content

Commit 3494331

Browse files
author
Molly Xu
committed
modified fallback logic
1 parent 1478117 commit 3494331

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

src/torchcodec/_core/Metadata.cpp

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
// LICENSE file in the root directory of this source tree.
66

77
#include "Metadata.h"
8+
#include "torch/types.h"
89

910
namespace facebook::torchcodec {
1011

1112
std::optional<double> StreamMetadata::getDurationSeconds(
1213
SeekMode seekMode) const {
1314
switch (seekMode) {
14-
case SeekMode::custom_frame_mappings:
1515
case SeekMode::exact:
16-
// In exact mode, use the scanned content value
16+
return endStreamPtsSecondsFromContent.value() -
17+
beginStreamPtsSecondsFromContent.value();
18+
case SeekMode::custom_frame_mappings:
1719
if (endStreamPtsSecondsFromContent.has_value() &&
1820
beginStreamPtsSecondsFromContent.has_value()) {
1921
return endStreamPtsSecondsFromContent.value() -
@@ -30,48 +32,51 @@ std::optional<double> StreamMetadata::getDurationSeconds(
3032
averageFpsFromHeader.value();
3133
}
3234
return std::nullopt;
35+
default:
36+
TORCH_CHECK(false, "Unknown SeekMode");
3337
}
34-
return std::nullopt;
3538
}
3639

3740
double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const {
3841
switch (seekMode) {
39-
case SeekMode::custom_frame_mappings:
4042
case SeekMode::exact:
43+
return beginStreamPtsSecondsFromContent.value();
44+
case SeekMode::custom_frame_mappings:
45+
case SeekMode::approximate:
4146
if (beginStreamPtsSecondsFromContent.has_value()) {
4247
return beginStreamPtsSecondsFromContent.value();
4348
}
4449
return 0.0;
45-
case SeekMode::approximate:
46-
return 0.0;
50+
default:
51+
TORCH_CHECK(false, "Unknown SeekMode");
4752
}
48-
return 0.0;
4953
}
5054

5155
std::optional<double> StreamMetadata::getEndStreamSeconds(
5256
SeekMode seekMode) const {
5357
switch (seekMode) {
54-
case SeekMode::custom_frame_mappings:
5558
case SeekMode::exact:
59+
return endStreamPtsSecondsFromContent.value();
60+
case SeekMode::custom_frame_mappings:
61+
case SeekMode::approximate:
5662
if (endStreamPtsSecondsFromContent.has_value()) {
5763
return endStreamPtsSecondsFromContent.value();
5864
}
5965
return getDurationSeconds(seekMode);
60-
case SeekMode::approximate:
61-
return getDurationSeconds(seekMode);
66+
default:
67+
TORCH_CHECK(false, "Unknown SeekMode");
6268
}
63-
return std::nullopt;
6469
}
6570

6671
std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
6772
switch (seekMode) {
68-
case SeekMode::custom_frame_mappings:
6973
case SeekMode::exact:
74+
return numFramesFromContent.value();
75+
case SeekMode::custom_frame_mappings:
76+
case SeekMode::approximate: {
7077
if (numFramesFromContent.has_value()) {
7178
return numFramesFromContent.value();
7279
}
73-
return std::nullopt;
74-
case SeekMode::approximate: {
7580
if (numFramesFromHeader.has_value()) {
7681
return numFramesFromHeader.value();
7782
}
@@ -82,14 +87,23 @@ std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
8287
}
8388
return std::nullopt;
8489
}
90+
default:
91+
TORCH_CHECK(false, "Unknown SeekMode");
8592
}
86-
return std::nullopt;
8793
}
8894

8995
std::optional<double> StreamMetadata::getAverageFps(SeekMode seekMode) const {
9096
switch (seekMode) {
91-
case SeekMode::custom_frame_mappings:
9297
case SeekMode::exact:
98+
if (endStreamPtsSecondsFromContent.value() !=
99+
beginStreamPtsSecondsFromContent.value()) {
100+
return static_cast<double>(
101+
getNumFrames(seekMode).value() /
102+
(endStreamPtsSecondsFromContent.value() -
103+
beginStreamPtsSecondsFromContent.value()));
104+
}
105+
case SeekMode::custom_frame_mappings:
106+
case SeekMode::approximate:
93107
if (getNumFrames(seekMode).has_value() &&
94108
beginStreamPtsSecondsFromContent.has_value() &&
95109
endStreamPtsSecondsFromContent.has_value() &&
@@ -101,10 +115,9 @@ std::optional<double> StreamMetadata::getAverageFps(SeekMode seekMode) const {
101115
beginStreamPtsSecondsFromContent.value()));
102116
}
103117
return averageFpsFromHeader;
104-
case SeekMode::approximate:
105-
return averageFpsFromHeader;
118+
default:
119+
TORCH_CHECK(false, "Unknown SeekMode");
106120
}
107-
return std::nullopt;
108121
}
109122

110123
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)