Skip to content

Commit 6e294a9

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into encode-twice
2 parents e4f05ce + d717d6c commit 6e294a9

File tree

4 files changed

+74
-26
lines changed

4 files changed

+74
-26
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@ torch::Tensor validateWf(torch::Tensor wf) {
1313
wf.dtype() == torch::kFloat32,
1414
"waveform must have float32 dtype, got ",
1515
wf.dtype());
16-
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
17-
// planar (fltp).
1816
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
19-
return wf;
17+
return wf.contiguous();
2018
}
2119

2220
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -878,8 +878,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
878878
while (!finished) {
879879
try {
880880
UniqueAVFrame avFrame =
881-
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
882-
return startPts < avFrame->pts + getDuration(avFrame);
881+
decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
882+
return startPts < avFrame->pts + getDuration(avFrame) &&
883+
stopPts > avFrame->pts;
883884
});
884885
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
885886
if (!firstFramePtsSeconds.has_value()) {
@@ -908,9 +909,12 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
908909
TORCH_CHECK(
909910
frames.size() > 0 && firstFramePtsSeconds.has_value(),
910911
"No audio frames were decoded. ",
911-
"This is probably because start_seconds is too high? ",
912-
"Current value is ",
913-
startSeconds);
912+
"This is probably because start_seconds is too high(",
913+
startSeconds,
914+
"),",
915+
"or because stop_seconds(",
916+
stopSecondsOptional,
917+
") is too low.");
914918

915919
return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
916920
}

test/test_decoders.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,8 @@ def test_start_equals_stop(self, asset):
11121112

11131113
def test_frame_start_is_not_zero(self):
11141114
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125.
1115-
# So if we request start = 0.05, we shouldn't be truncating anything.
1115+
# So if we request (start, stop) = (0.05, None), we shouldn't be
1116+
# truncating anything.
11161117

11171118
asset = NASA_AUDIO_MP3
11181119
start_seconds = 0.05 # this is less than the first frame's pts
@@ -1128,6 +1129,35 @@ def test_frame_start_is_not_zero(self):
11281129
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
11291130
torch.testing.assert_close(samples.data, reference_frames)
11301131

1132+
# Non-regression test for https://github.com/pytorch/torchcodec/issues/567
1133+
# If we ask for start < stop <= first_frame_pts, we should raise.
1134+
with pytest.raises(RuntimeError, match="No audio frames were decoded"):
1135+
decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=0.05)
1136+
1137+
first_frame_pts_seconds = asset.get_frame_info(idx=0).pts_seconds
1138+
with pytest.raises(RuntimeError, match="No audio frames were decoded"):
1139+
decoder.get_samples_played_in_range(
1140+
start_seconds=0, stop_seconds=first_frame_pts_seconds
1141+
)
1142+
1143+
# Documenting an edge case: we ask for samples barely beyond the start
1144+
# of the first frame. The C++ decoder returns the first frame, which
1145+
# gets (correctly!) truncated by the AudioDecoder, and we end up with
1146+
# empty data.
1147+
samples = decoder.get_samples_played_in_range(
1148+
start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-5
1149+
)
1150+
assert samples.data.shape == (2, 0)
1151+
assert samples.pts_seconds == first_frame_pts_seconds
1152+
assert samples.duration_seconds == 0
1153+
1154+
# if we ask for a little bit more samples, we get non-empty data
1155+
samples = decoder.get_samples_played_in_range(
1156+
start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-3
1157+
)
1158+
assert samples.data.shape == (2, 8)
1159+
assert samples.pts_seconds == first_frame_pts_seconds
1160+
11311161
def test_single_channel(self):
11321162
asset = SINE_MONO_S32
11331163
decoder = AudioDecoder(asset.path)

test/test_ops.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -884,23 +884,6 @@ def test_pts(self, asset):
884884

885885
assert pts_seconds == start_seconds
886886

887-
def test_decode_before_frame_start(self):
888-
# Test illustrating bug described in
889-
# https://github.com/pytorch/torchcodec/issues/567
890-
asset = NASA_AUDIO_MP3
891-
892-
decoder = create_from_file(str(asset.path), seek_mode="approximate")
893-
add_audio_stream(decoder)
894-
895-
frames, *_ = get_frames_by_pts_in_range_audio(
896-
decoder, start_seconds=0, stop_seconds=0.05
897-
)
898-
all_frames, *_ = get_frames_by_pts_in_range_audio(
899-
decoder, start_seconds=0, stop_seconds=None
900-
)
901-
# TODO fix this. `frames` should be empty.
902-
torch.testing.assert_close(frames, all_frames)
903-
904887
def test_sample_rate_conversion(self):
905888
def get_all_frames(asset, sample_rate=None, stop_seconds=None):
906889
decoder = create_from_file(str(asset.path), seek_mode="approximate")
@@ -1284,6 +1267,39 @@ def test_encode_to_tensor_long_output(self):
12841267

12851268
torch.testing.assert_close(self.decode(encoded_tensor), samples)
12861269

1270+
def test_contiguity(self):
1271+
# Ensure that 2 waveforms with the same values are encoded in the same
1272+
# way, regardless of their memory layout. Here we encode 2 equal
1273+
# waveforms, one is row-aligned while the other is column-aligned.
1274+
1275+
num_samples = 10_000 # per channel
1276+
contiguous_samples = torch.rand(2, num_samples).contiguous()
1277+
assert contiguous_samples.stride() == (num_samples, 1)
1278+
1279+
encoded_from_contiguous = encode_audio_to_tensor(
1280+
wf=contiguous_samples,
1281+
sample_rate=16_000,
1282+
format="flac",
1283+
bit_rate=44_000,
1284+
)
1285+
non_contiguous_samples = contiguous_samples.T.contiguous().T
1286+
assert non_contiguous_samples.stride() == (1, 2)
1287+
1288+
torch.testing.assert_close(
1289+
contiguous_samples, non_contiguous_samples, rtol=0, atol=0
1290+
)
1291+
1292+
encoded_from_non_contiguous = encode_audio_to_tensor(
1293+
wf=non_contiguous_samples,
1294+
sample_rate=16_000,
1295+
format="flac",
1296+
bit_rate=44_000,
1297+
)
1298+
1299+
torch.testing.assert_close(
1300+
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
1301+
)
1302+
12871303

12881304
if __name__ == "__main__":
12891305
pytest.main()

0 commit comments

Comments
 (0)