Skip to content

Commit 691dde7

Browse files
committed
Use 'status' instead of ffmpegRet
1 parent 01dc1b1 commit 691dde7

File tree

2 files changed

+62
-45
lines changed

2 files changed

+62
-45
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,20 @@ Encoder::~Encoder() {}
1313
Encoder::Encoder(int sampleRate, std::string_view fileName)
1414
: sampleRate_(sampleRate) {
1515
AVFormatContext* avFormatContext = nullptr;
16-
avformat_alloc_output_context2(&avFormatContext, nullptr, nullptr, fileName.data());
16+
avformat_alloc_output_context2(
17+
&avFormatContext, nullptr, nullptr, fileName.data());
1718
TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext.");
1819
avFormatContext_.reset(avFormatContext);
1920

2021
TORCH_CHECK(
2122
!(avFormatContext->oformat->flags & AVFMT_NOFILE),
2223
"AVFMT_NOFILE is set. We only support writing to a file.");
23-
auto ffmpegRet =
24+
auto status =
2425
avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
2526
TORCH_CHECK(
26-
ffmpegRet >= 0,
27+
status >= 0,
2728
"avio_open failed: ",
28-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
29+
getFFMPEGErrorStringFromErrorCode(status));
2930

3031
// We use the AVFormatContext's default codec for that
3132
// specificavcodec_parameters_from_context format/container.
@@ -39,7 +40,8 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
3940

4041
// This will use the default bit rate
4142
// TODO-ENCODING Should let user choose for compressed formats like mp3.
42-
avCodecContext_->bit_rate = 0;
43+
// avCodecContext_->bit_rate = 0;
44+
avCodecContext_->bit_rate = 24000;
4345

4446
// FFmpeg will raise a reasonably informative error if the desired sample rate
4547
// isn't supported by the encoder.
@@ -58,9 +60,8 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
5860
av_channel_layout_default(&channel_layout, 2);
5961
avCodecContext_->ch_layout = channel_layout;
6062

61-
ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
62-
TORCH_CHECK(
63-
ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet));
63+
status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
64+
TORCH_CHECK(status == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(status));
6465

6566
TORCH_CHECK(
6667
avCodecContext_->frame_size > 0,
@@ -83,18 +84,18 @@ void Encoder::encode(const torch::Tensor& wf) {
8384
avFrame->format = avCodecContext_->sample_fmt;
8485
avFrame->sample_rate = avCodecContext_->sample_rate;
8586
avFrame->pts = 0;
86-
auto ffmpegRet =
87+
auto status =
8788
av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout);
8889
TORCH_CHECK(
89-
ffmpegRet == AVSUCCESS,
90+
status == AVSUCCESS,
9091
"Couldn't copy channel layout to avFrame: ",
91-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
92+
getFFMPEGErrorStringFromErrorCode(status));
9293

93-
ffmpegRet = av_frame_get_buffer(avFrame.get(), 0);
94+
status = av_frame_get_buffer(avFrame.get(), 0);
9495
TORCH_CHECK(
95-
ffmpegRet == AVSUCCESS,
96+
status == AVSUCCESS,
9697
"Couldn't allocate avFrame's buffers: ",
97-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
98+
getFFMPEGErrorStringFromErrorCode(status));
9899

99100
AutoAVPacket autoAVPacket;
100101

@@ -117,18 +118,18 @@ void Encoder::encode(const torch::Tensor& wf) {
117118
AV_NUM_DATA_POINTERS,
118119
" channels per frame.");
119120

120-
ffmpegRet = avformat_write_header(avFormatContext_.get(), nullptr);
121+
status = avformat_write_header(avFormatContext_.get(), nullptr);
121122
TORCH_CHECK(
122-
ffmpegRet == AVSUCCESS,
123+
status == AVSUCCESS,
123124
"Error in avformat_write_header: ",
124-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
125+
getFFMPEGErrorStringFromErrorCode(status));
125126

126127
while (numEncodedSamples < numSamples) {
127-
ffmpegRet = av_frame_make_writable(avFrame.get());
128+
status = av_frame_make_writable(avFrame.get());
128129
TORCH_CHECK(
129-
ffmpegRet == AVSUCCESS,
130+
status == AVSUCCESS,
130131
"Couldn't make AVFrame writable: ",
131-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
132+
getFFMPEGErrorStringFromErrorCode(status));
132133

133134
auto numSamplesToEncode =
134135
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
@@ -148,52 +149,51 @@ void Encoder::encode(const torch::Tensor& wf) {
148149

149150
encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush
150151

151-
ffmpegRet = av_write_trailer(avFormatContext_.get());
152+
status = av_write_trailer(avFormatContext_.get());
152153
TORCH_CHECK(
153-
ffmpegRet == AVSUCCESS,
154+
status == AVSUCCESS,
154155
"Error in: av_write_trailer",
155-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
156+
getFFMPEGErrorStringFromErrorCode(status));
156157
}
157158

158159
void Encoder::encode_inner_loop(
159160
AutoAVPacket& autoAVPacket,
160161
const UniqueAVFrame& avFrame) {
161-
auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
162+
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
162163
TORCH_CHECK(
163-
ffmpegRet == AVSUCCESS,
164+
status == AVSUCCESS,
164165
"Error while sending frame: ",
165-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
166+
getFFMPEGErrorStringFromErrorCode(status));
166167

167-
while (ffmpegRet >= 0) {
168+
while (status >= 0) {
168169
ReferenceAVPacket packet(autoAVPacket);
169-
ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), packet.get());
170-
if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) {
170+
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
171+
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
171172
// TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
172-
// if (ffmpegRet == AVERROR_EOF) {
173-
// ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(),
173+
// if (status == AVERROR_EOF) {
174+
// status = av_interleaved_write_frame(avFormatContext_.get(),
174175
// nullptr); TORCH_CHECK(
175-
// ffmpegRet == AVSUCCESS,
176+
// status == AVSUCCESS,
176177
// "Failed to flush packet ",
177-
// getFFMPEGErrorStringFromErrorCode(ffmpegRet));
178+
// getFFMPEGErrorStringFromErrorCode(status));
178179
// }
179180
return;
180181
}
181182
TORCH_CHECK(
182-
ffmpegRet >= 0,
183+
status >= 0,
183184
"Error receiving packet: ",
184-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
185+
getFFMPEGErrorStringFromErrorCode(status));
185186

186187
// TODO-ENCODING why are these 2 lines needed??
187188
av_packet_rescale_ts(
188189
packet.get(), avCodecContext_->time_base, avStream_->time_base);
189190
packet->stream_index = avStream_->index;
190191

191-
ffmpegRet =
192-
av_interleaved_write_frame(avFormatContext_.get(), packet.get());
192+
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
193193
TORCH_CHECK(
194-
ffmpegRet == AVSUCCESS,
194+
status == AVSUCCESS,
195195
"Error in av_interleaved_write_frame: ",
196-
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
196+
getFFMPEGErrorStringFromErrorCode(status));
197197
}
198198
}
199199
} // namespace facebook::torchcodec

test/decoders/test_ops.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -942,16 +942,33 @@ def decode(self, source) -> torch.Tensor:
942942

943943
def test_round_trip(self, tmp_path):
944944
asset = SINE_MONO_S32
945-
source_samples = self.decode(asset)
946945

947-
output_file = tmp_path / "output.mp3"
946+
encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
947+
encoded_by_us = tmp_path / "our_output.mp3"
948+
949+
command = [
950+
"ffmpeg",
951+
"-i",
952+
str(asset.path),
953+
# '-vn',
954+
# '-ar', '44100', # Set audio sampling rate
955+
# '-ac', '2', # Set number of audio channels
956+
# '-b:a', '192k', # Set audio bitrate
957+
str(encoded_by_ffmpeg),
958+
]
959+
subprocess.run(command, check=True)
960+
948961
encoder = create_encoder(
949-
sample_rate=asset.sample_rate, filename=str(output_file)
962+
sample_rate=asset.sample_rate, filename=str(encoded_by_us)
950963
)
951-
encode(encoder, source_samples)
952964

953-
round_trip_samples = self.decode(output_file)
954-
torch.testing.assert_close(source_samples, round_trip_samples)
965+
encode(encoder, self.decode(asset))
966+
967+
print(encoded_by_ffmpeg)
968+
print(encoded_by_us)
969+
from_ffmpeg = self.decode(encoded_by_ffmpeg)
970+
from_us = self.decode(encoded_by_us)
971+
torch.testing.assert_close(from_us, from_ffmpeg)
955972

956973

957974
if __name__ == "__main__":

0 commit comments

Comments
 (0)