Skip to content

Commit 3cec761

Browse files
committed
Add tests
1 parent eb2a86c commit 3cec761

File tree

5 files changed

+94
-78
lines changed

5 files changed

+94
-78
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@ Encoder::~Encoder() {}
1212

1313
// TODO-ENCODING: disable ffmpeg logs by default
1414

15-
Encoder::Encoder(int sampleRate, std::string_view fileName)
16-
: sampleRate_(sampleRate) {
15+
Encoder::Encoder(
16+
const torch::Tensor wf,
17+
int sampleRate,
18+
std::string_view fileName)
19+
: wf_(wf), sampleRate_(sampleRate) {
1720
AVFormatContext* avFormatContext = nullptr;
1821
avformat_alloc_output_context2(
1922
&avFormatContext, nullptr, nullptr, fileName.data());
2023
TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext.");
2124
avFormatContext_.reset(avFormatContext);
2225

26+
// TODO-ENCODING: Should also support encoding into bytes (use
27+
// AVIOBytesContext)
2328
TORCH_CHECK(
2429
!(avFormatContext->oformat->flags & AVFMT_NOFILE),
2530
"AVFMT_NOFILE is set. We only support writing to a file.");
@@ -31,7 +36,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
3136
getFFMPEGErrorStringFromErrorCode(status));
3237

3338
// We use the AVFormatContext's default codec for that
34-
// specificavcodec_parameters_from_context format/container.
39+
// specific format/container.
3540
const AVCodec* avCodec =
3641
avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
3742
TORCH_CHECK(avCodec != nullptr, "Codec not found");
@@ -40,9 +45,10 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
4045
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
4146
avCodecContext_.reset(avCodecContext);
4247

43-
// This will use the default bit rate
44-
// TODO-ENCODING Should let user choose for compressed formats like mp3.
45-
// avCodecContext_->bit_rate = 64000;
48+
// TODO-ENCODING I think this sets the bit rate to the minimum supported.
49+
// That's not what the ffmpeg CLI would choose by default, so we should try to
50+
// match the CLI's default instead.
51+
// TODO-ENCODING Should also let user choose for compressed formats like mp3.
4652
avCodecContext_->bit_rate = 0;
4753

4854
// FFmpeg will raise a reasonably informative error if the desired sample rate
@@ -58,8 +64,19 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
5864
// libswresample.
5965
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
6066

67+
auto numChannels = wf_.sizes()[0];
68+
TORCH_CHECK(
69+
// TODO-ENCODING is this even true / needed? We can probably support more
70+
// with non-planar data?
71+
numChannels <= AV_NUM_DATA_POINTERS,
72+
"Trying to encode ",
73+
numChannels,
74+
" channels, but FFmpeg only supports ",
75+
AV_NUM_DATA_POINTERS,
76+
" channels per frame.");
77+
6178
AVChannelLayout channel_layout;
62-
av_channel_layout_default(&channel_layout, 2);
79+
av_channel_layout_default(&channel_layout, numChannels);
6380
avCodecContext_->ch_layout = channel_layout;
6481

6582
status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
@@ -79,7 +96,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName)
7996
avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get());
8097
}
8198

82-
void Encoder::encode(const torch::Tensor& wf) {
99+
void Encoder::encode() {
83100
UniqueAVFrame avFrame(av_frame_alloc());
84101
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
85102
avFrame->nb_samples = avCodecContext_->frame_size;
@@ -101,24 +118,13 @@ void Encoder::encode(const torch::Tensor& wf) {
101118

102119
AutoAVPacket autoAVPacket;
103120

104-
uint8_t* pWf = static_cast<uint8_t*>(wf.data_ptr());
105-
auto numChannels = wf.sizes()[0];
106-
auto numSamples = wf.sizes()[1]; // per channel
121+
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
122+
auto numSamples = wf_.sizes()[1]; // per channel
107123
auto numEncodedSamples = 0; // per channel
108124
auto numSamplesPerFrame =
109125
static_cast<long>(avCodecContext_->frame_size); // per channel
110-
auto numBytesPerSample = wf.element_size();
111-
auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample;
112-
113-
TORCH_CHECK(
114-
// TODO-ENCODING is this even true / needed? We can probably support more
115-
// with non-planar data?
116-
numChannels <= AV_NUM_DATA_POINTERS,
117-
"Trying to encode ",
118-
numChannels,
119-
" channels, but FFmpeg only supports ",
120-
AV_NUM_DATA_POINTERS,
121-
" channels per frame.");
126+
auto numBytesPerSample = wf_.element_size();
127+
auto numBytesPerChannel = numSamples * numBytesPerSample;
122128

123129
status = avformat_write_header(avFormatContext_.get(), nullptr);
124130
TORCH_CHECK(
@@ -136,16 +142,22 @@ void Encoder::encode(const torch::Tensor& wf) {
136142
auto numSamplesToEncode =
137143
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
138144
auto numBytesToEncode = numSamplesToEncode * numBytesPerSample;
139-
avFrame->nb_samples = std::min(static_cast<int64_t>(avCodecContext_->frame_size), numSamplesToEncode);
140145

141-
for (int ch = 0; ch < numChannels; ch++) {
146+
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
142147
memcpy(
143-
avFrame->data[ch], pWf + ch * numBytesPerChannel, numBytesToEncode);
148+
avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
144149
}
145-
pWf += numBytesToEncode;
150+
pwf += numBytesToEncode;
151+
152+
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
153+
// that the frame buffers are allocated to a big enough size. Here, we reset
154+
// it to the exact number of samples that need to be encoded, otherwise the
155+
// encoded frame would contain more samples than necessary and our results
156+
// wouldn't match the ffmpeg CLI.
157+
avFrame->nb_samples = numSamplesToEncode;
146158
encode_inner_loop(autoAVPacket, avFrame);
147159

148-
avFrame->pts += avFrame->nb_samples;
160+
avFrame->pts += numSamplesToEncode;
149161
numEncodedSamples += numSamplesToEncode;
150162
}
151163
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
@@ -163,11 +175,6 @@ void Encoder::encode_inner_loop(
163175
AutoAVPacket& autoAVPacket,
164176
const UniqueAVFrame& avFrame) {
165177
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
166-
// if (avFrame.get()) {
167-
// printf("Sending frame with %d samples\n", avFrame->nb_samples);
168-
// } else {
169-
// printf("Flushing\n");
170-
// }
171178
TORCH_CHECK(
172179
status == AVSUCCESS,
173180
"Error while sending frame: ",

src/torchcodec/_core/Encoder.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,8 @@ class Encoder {
77
public:
88
~Encoder();
99

10-
// TODO Are we OK passing a string_view to the constructor?
11-
// TODO fileName should be optional.
12-
// TODO doesn't make much sense to pass fileName and the wf tensor in 2
13-
// different calls. Same with sampleRate.
14-
Encoder(int sampleRate, std::string_view fileName);
15-
void encode(const torch::Tensor& wf);
10+
Encoder(const torch::Tensor wf, int sampleRate, std::string_view fileName);
11+
void encode();
1612

1713
private:
1814
void encode_inner_loop(
@@ -31,5 +27,6 @@ class Encoder {
3127
// resample the waveform internally to match them, but that's not in scope for
3228
// an initial version (if at all).
3329
int sampleRate_;
30+
const torch::Tensor wf_;
3431
};
3532
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2828
m.impl_abstract_pystub(
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
31-
m.def("create_encoder(int sample_rate, str filename) -> Tensor");
32-
m.def("encode(Tensor(a!) encoder, Tensor wf) -> ()");
31+
m.def("create_encoder(Tensor wf, int sample_rate, str filename) -> Tensor");
32+
m.def("encode(Tensor(a!) encoder) -> ()");
3333
m.def(
3434
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3535
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -194,15 +194,18 @@ at::Tensor create_from_file(
194194
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
195195
}
196196

197-
at::Tensor create_encoder(int64_t sample_rate, std::string_view file_name) {
197+
at::Tensor create_encoder(
198+
const at::Tensor wf,
199+
int64_t sample_rate,
200+
std::string_view file_name) {
198201
std::unique_ptr<Encoder> uniqueEncoder =
199-
std::make_unique<Encoder>(static_cast<int>(sample_rate), file_name);
202+
std::make_unique<Encoder>(wf, static_cast<int>(sample_rate), file_name);
200203
return wrapEncoderPointerToTensor(std::move(uniqueEncoder));
201204
}
202205

203-
void encode(at::Tensor& encoder, const at::Tensor& wf) {
206+
void encode(at::Tensor& encoder) {
204207
auto encoder_ = unwrapTensorToGetEncoder(encoder);
205-
encoder_->encode(wf);
208+
encoder_->encode();
206209
}
207210

208211
// Create a VideoDecoder from the actual bytes of a video and wrap the pointer

src/torchcodec/_core/ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,14 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
160160

161161

162162
@register_fake("torchcodec_ns::create_encoder")
163-
def create_encoder_abstract(sample_rate: int, filename: str) -> torch.Tensor:
163+
def create_encoder_abstract(
164+
wf: torch.Tensor, sample_rate: int, filename: str
165+
) -> torch.Tensor:
164166
return torch.empty([], dtype=torch.long)
165167

166168

167169
@register_fake("torchcodec_ns::encode")
168-
def encode_abstract(encoder: torch.Tensor, wf: torch.Tensor) -> torch.Tensor:
170+
def encode_abstract(encoder: torch.Tensor) -> torch.Tensor:
169171
return torch.empty([], dtype=torch.long)
170172

171173

test/decoders/test_ops.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -940,48 +940,55 @@ def decode(self, source) -> torch.Tensor:
940940
)
941941
return frames
942942

943-
# def test_round_trip(self, tmp_path):
944-
# asset = NASA_AUDIO_MP3
945-
946-
# encoded_path = tmp_path / "output.mp3"
947-
# encoder = create_encoder(
948-
# sample_rate=asset.sample_rate, filename=str(encoded_path)
949-
# )
950-
951-
# source_samples = self.decode(asset)
952-
# encode(encoder, source_samples)
943+
def test_round_trip(self, tmp_path):
944+
# Check that decode(encode(samples)) is close to samples (within a tolerance, not exactly equal)
945+
asset = NASA_AUDIO_MP3
946+
source_samples = self.decode(asset)
953947

954-
# torch.testing.assert_close(self.decode(encoded_path), source_samples)
948+
encoded_path = tmp_path / "output.mp3"
949+
encoder = create_encoder(
950+
wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path)
951+
)
952+
encode(encoder)
955953

956-
def test_against_cli(self, tmp_path):
954+
# TODO-ENCODING: tol should be stricter. We need to increase the encoded
955+
# bitrate, and / or encode into a lossless format.
956+
torch.testing.assert_close(
957+
self.decode(encoded_path), source_samples, rtol=0, atol=0.07
958+
)
957959

958-
asset = NASA_AUDIO_MP3
960+
# TODO-ENCODING: test more encoding formats
961+
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
962+
def test_against_cli(self, asset, tmp_path):
963+
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
964+
# that both decoded outputs are equal
959965

960966
encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
961967
encoded_by_us = tmp_path / "our_output.mp3"
962968

963-
command = [
964-
"ffmpeg",
965-
"-i",
966-
str(asset.path),
967-
# '-vn',
968-
# '-ar', '16000', # Set audio sampling rate
969-
# '-ac', '2', # Set number of audio channels
970-
# '-b:a', '192k', # Set audio bitrate
971-
'-b:a', '0', # Set audio bitrate
972-
str(encoded_by_ffmpeg),
973-
]
974-
subprocess.run(command, check=True)
969+
subprocess.run(
970+
[
971+
"ffmpeg",
972+
"-i",
973+
str(asset.path),
974+
"-b:a",
975+
"0", # bitrate hardcoded to 0, see corresponding TODO.
976+
str(encoded_by_ffmpeg),
977+
],
978+
capture_output=True,
979+
check=True,
980+
)
975981

976982
encoder = create_encoder(
977-
sample_rate=asset.sample_rate, filename=str(encoded_by_us)
983+
wf=self.decode(asset),
984+
sample_rate=asset.sample_rate,
985+
filename=str(encoded_by_us),
978986
)
987+
encode(encoder)
979988

980-
encode(encoder, self.decode(asset))
981-
982-
from_ffmpeg = self.decode(encoded_by_ffmpeg)
983-
from_us = self.decode(encoded_by_us)
984-
torch.testing.assert_close(from_us, from_ffmpeg)
989+
torch.testing.assert_close(
990+
self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us)
991+
)
985992

986993

987994
if __name__ == "__main__":

0 commit comments

Comments
 (0)