Skip to content

Commit 83c75b5

Browse files
committed
Fix mp3 tests
1 parent 50938fe commit 83c75b5

File tree

5 files changed

+62
-29
lines changed

5 files changed

+62
-29
lines changed

src/torchcodec/_core/AVIOBytesContext.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
6969

7070
AVIOToTensorContext::AVIOToTensorContext()
7171
: dataContext_{torch::empty({OUTPUT_TENSOR_SIZE}, {torch::kUInt8}), 0} {
72-
createAVIOContext(nullptr, &write, nullptr, &dataContext_);
72+
createAVIOContext(nullptr, &write, &seek, &dataContext_);
7373
}
7474

7575
// The signature of this function is defined by FFMPEG.
@@ -84,6 +84,26 @@ int AVIOToTensorContext::write(void* opaque, uint8_t* buf, int buf_size) {
8484
return buf_size;
8585
}
8686

87+
// The signature of this function is defined by FFMPEG.
88+
int64_t AVIOToTensorContext::seek(void* opaque, int64_t offset, int whence) {
89+
auto dataContext = static_cast<DataContext*>(opaque);
90+
int64_t ret = -1;
91+
92+
switch (whence) {
93+
case AVSEEK_SIZE:
94+
ret = dataContext->outputTensor.numel();
95+
break;
96+
case SEEK_SET:
97+
dataContext->current = offset;
98+
ret = offset;
99+
break;
100+
default:
101+
break;
102+
}
103+
104+
return ret;
105+
}
106+
87107
torch::Tensor AVIOToTensorContext::getOutputTensor() {
88108
return dataContext_.outputTensor.narrow(
89109
/*dim=*/0, /*start=*/0, /*length=*/dataContext_.current);

src/torchcodec/_core/AVIOBytesContext.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class AVIOToTensorContext : public AVIOContextHolder {
4646

4747
static const int OUTPUT_TENSOR_SIZE = 5'000'000; // TODO-ENCODING handle this
4848
static int write(void* opaque, uint8_t* buf, int buf_size);
49+
static int64_t seek(void* opaque, int64_t offset, int whence);
4950

5051
DataContext dataContext_;
5152
};

src/torchcodec/_core/AVIOContextHolder.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ namespace facebook::torchcodec {
1818
// UniqueAVIOContext, as the AVIOContext points to a buffer which must be
1919
// freed.
2020
// 2. It is a base class for AVIOContext specializations. When specializing a
21-
// AVIOContext, we need to provide:
22-
// 1. - For decoding: A read callback function and a seek callback
23-
// function.
24-
// - For encoding: A write callback function.
25-
// 2. A pointer to some context object that has the same lifetime as the
21+
// AVIOContext, we need to provide four things:
22+
// 1. A read callback function, for decoding.
23+
// 2. A seek callback function, for decoding and encoding.
24+
// 3. A write callback function, for encoding>
25+
// 4. A pointer to some context object that has the same lifetime as the
2626
// AVIOContext itself. This context object holds the custom state that
2727
// tracks the custom behavior of reading, seeking and writing. It is
2828
// provided upon AVIOContext creation and to the read, seek and

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ AudioEncoder::AudioEncoder(
5757
TORCH_CHECK(
5858
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
5959

60-
avioContextHolder_ = std::make_unique<AVIOToTensorContext>();
61-
6260
setFFmpegLogLevel();
6361
AVFormatContext* avFormatContext = nullptr;
6462
int status = AVSUCCESS;
@@ -84,6 +82,7 @@ AudioEncoder::AudioEncoder(
8482
"avio_open failed: ",
8583
getFFMPEGErrorStringFromErrorCode(status));
8684
} else {
85+
avioContextHolder_ = std::make_unique<AVIOToTensorContext>();
8786
avFormatContext->pb = avioContextHolder_->getAVIOContext();
8887
}
8988

test/test_ops.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,12 +1162,9 @@ def test_round_trip(self, encode_method, output_format, tmp_path):
11621162

11631163
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
11641164
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
1165-
@pytest.mark.parametrize(
1166-
"encode_method", (encode_audio_to_file, encode_audio_to_tensor)
1167-
)
11681165
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
11691166
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
1170-
def test_against_cli(self, asset, encode_method, bit_rate, output_format, tmp_path):
1167+
def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
11711168
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
11721169
# that both decoded outputs are equal
11731170

@@ -1186,24 +1183,14 @@ def test_against_cli(self, asset, encode_method, bit_rate, output_format, tmp_pa
11861183
check=True,
11871184
)
11881185

1189-
if encode_method is encode_audio_to_file:
1190-
encoded_by_us = tmp_path / f"our_output.{output_format}"
1191-
encode_audio_to_file(
1192-
wf=self.decode(asset),
1193-
sample_rate=asset.sample_rate,
1194-
filename=str(encoded_by_us),
1195-
bit_rate=bit_rate,
1196-
)
1197-
else:
1198-
encoded_by_us = encode_audio_to_tensor(
1199-
wf=self.decode(asset),
1200-
sample_rate=asset.sample_rate,
1201-
format=output_format,
1202-
bit_rate=bit_rate,
1203-
)
1186+
encoded_by_us = tmp_path / f"our_output.{output_format}"
1187+
encode_audio_to_file(
1188+
wf=self.decode(asset),
1189+
sample_rate=asset.sample_rate,
1190+
filename=str(encoded_by_us),
1191+
bit_rate=bit_rate,
1192+
)
12041193

1205-
if output_format == "mp3" and encode_method is encode_audio_to_tensor:
1206-
pytest.skip("TODO-ENCODING investigate, decoded lengths are slightly different")
12071194
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
12081195
torch.testing.assert_close(
12091196
self.decode(encoded_by_ffmpeg),
@@ -1212,6 +1199,32 @@ def test_against_cli(self, asset, encode_method, bit_rate, output_format, tmp_pa
12121199
atol=atol,
12131200
)
12141201

1202+
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
1203+
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
1204+
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
1205+
def test_tensor_against_file(self, asset, bit_rate, output_format, tmp_path):
1206+
if get_ffmpeg_major_version() == 4 and output_format == "wav":
1207+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
1208+
1209+
encoded_file = tmp_path / f"our_output.{output_format}"
1210+
encode_audio_to_file(
1211+
wf=self.decode(asset),
1212+
sample_rate=asset.sample_rate,
1213+
filename=str(encoded_file),
1214+
bit_rate=bit_rate,
1215+
)
1216+
1217+
encoded_tensor = encode_audio_to_tensor(
1218+
wf=self.decode(asset),
1219+
sample_rate=asset.sample_rate,
1220+
format=output_format,
1221+
bit_rate=bit_rate,
1222+
)
1223+
1224+
torch.testing.assert_close(
1225+
self.decode(encoded_file), self.decode(encoded_tensor)
1226+
)
1227+
12151228

12161229
if __name__ == "__main__":
12171230
pytest.main()

0 commit comments

Comments
 (0)