Skip to content

Commit eb2a86c

Browse files
committed
Stuff
1 parent 691dde7 commit eb2a86c

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ namespace facebook::torchcodec {
1010

1111
Encoder::~Encoder() {}
1212

13+
// TODO-ENCODING: disable ffmpeg logs by default
14+
1315
Encoder::Encoder(int sampleRate, std::string_view fileName)
1416
: sampleRate_(sampleRate) {
1517
AVFormatContext* avFormatContext = nullptr;
@@ -40,8 +42,8 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
4042

4143
// This will use the default bit rate
4244
// TODO-ENCODING Should let user choose for compressed formats like mp3.
43-
// avCodecContext_->bit_rate = 0;
44-
avCodecContext_->bit_rate = 24000;
45+
// avCodecContext_->bit_rate = 64000;
46+
avCodecContext_->bit_rate = 0;
4547

4648
// FFmpeg will raise a reasonably informative error if the desired sample rate
4749
// isn't supported by the encoder.
@@ -134,6 +136,7 @@ void Encoder::encode(const torch::Tensor& wf) {
134136
auto numSamplesToEncode =
135137
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
136138
auto numBytesToEncode = numSamplesToEncode * numBytesPerSample;
139+
avFrame->nb_samples = std::min(static_cast<int64_t>(avCodecContext_->frame_size), numSamplesToEncode);
137140

138141
for (int ch = 0; ch < numChannels; ch++) {
139142
memcpy(
@@ -160,6 +163,11 @@ void Encoder::encode_inner_loop(
160163
AutoAVPacket& autoAVPacket,
161164
const UniqueAVFrame& avFrame) {
162165
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
166+
// if (avFrame.get()) {
167+
// printf("Sending frame with %d samples\n", avFrame->nb_samples);
168+
// } else {
169+
// printf("Flushing\n");
170+
// }
163171
TORCH_CHECK(
164172
status == AVSUCCESS,
165173
"Error while sending frame: ",

test/decoders/test_ops.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -940,8 +940,22 @@ def decode(self, source) -> torch.Tensor:
940940
)
941941
return frames
942942

943-
def test_round_trip(self, tmp_path):
944-
asset = SINE_MONO_S32
943+
# def test_round_trip(self, tmp_path):
944+
# asset = NASA_AUDIO_MP3
945+
946+
# encoded_path = tmp_path / "output.mp3"
947+
# encoder = create_encoder(
948+
# sample_rate=asset.sample_rate, filename=str(encoded_path)
949+
# )
950+
951+
# source_samples = self.decode(asset)
952+
# encode(encoder, source_samples)
953+
954+
# torch.testing.assert_close(self.decode(encoded_path), source_samples)
955+
956+
def test_against_cli(self, tmp_path):
957+
958+
asset = NASA_AUDIO_MP3
945959

946960
encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
947961
encoded_by_us = tmp_path / "our_output.mp3"
@@ -951,9 +965,10 @@ def test_round_trip(self, tmp_path):
951965
"-i",
952966
str(asset.path),
953967
# '-vn',
954-
# '-ar', '44100', # Set audio sampling rate
968+
# '-ar', '16000', # Set audio sampling rate
955969
# '-ac', '2', # Set number of audio channels
956970
# '-b:a', '192k', # Set audio bitrate
971+
'-b:a', '0', # Set audio bitrate
957972
str(encoded_by_ffmpeg),
958973
]
959974
subprocess.run(command, check=True)
@@ -964,8 +979,6 @@ def test_round_trip(self, tmp_path):
964979

965980
encode(encoder, self.decode(asset))
966981

967-
print(encoded_by_ffmpeg)
968-
print(encoded_by_us)
969982
from_ffmpeg = self.decode(encoded_by_ffmpeg)
970983
from_us = self.decode(encoded_by_us)
971984
torch.testing.assert_close(from_us, from_ffmpeg)

0 commit comments

Comments
 (0)