Skip to content

Commit 17340a6

Browse files
committed
Properly set frames pts
1 parent 17cd1d8 commit 17340a6

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,6 @@ void AudioEncoder::encode() {
318318
encodeFrameThroughFifo(autoAVPacket, convertedAVFrame);
319319

320320
numEncodedSamples += numSamplesToEncode;
321-
avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
322321
}
323322
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
324323

@@ -405,6 +404,11 @@ void AudioEncoder::encodeFrameThroughFifo(
405404
void AudioEncoder::encodeFrame(
406405
AutoAVPacket& autoAVPacket,
407406
const UniqueAVFrame& avFrame) {
407+
if (avFrame != nullptr) {
408+
avFrame->pts = lastEncodedAVFramePts_;
409+
lastEncodedAVFramePts_ += avFrame->nb_samples;
410+
}
411+
408412
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
409413
TORCH_CHECK(
410414
status == AVSUCCESS,

src/torchcodec/_core/Encoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,6 @@ class AudioEncoder {
5858
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
5959

6060
bool encodeWasCalled_ = false;
61+
int64_t lastEncodedAVFramePts_ = 0;
6162
};
6263
} // namespace facebook::torchcodec

test/test_encoders.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_against_cli(
229229
["ffmpeg", "-i", str(asset.path)]
230230
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
231231
+ (["-ac", f"{num_channels}"] if num_channels is not None else [])
232-
+ (["-ar", f"{sample_rate}"] if sample_rate is not None else [])
232+
+ ["-ar", f"{sample_rate}"]
233233
+ [
234234
str(encoded_by_ffmpeg),
235235
],
@@ -247,17 +247,19 @@ def test_against_cli(
247247
else:
248248
encoded_by_us = encoder.to_tensor(format=format, **params)
249249

250-
# captured = capfd.readouterr()
251-
# if format == "wav":
252-
# assert "Timestamps are unset in a packet" not in captured.err
253-
# if format == "mp3":
254-
# assert "Queue input is backward in time" not in captured.err
255-
# if format in ("flac", "wav"):
256-
# assert "Encoder did not produce proper pts" not in captured.err
257-
# if format in ("flac", "mp3"):
258-
# assert "Application provided invalid" not in captured.err
259-
250+
captured = capfd.readouterr()
260251
if format == "wav":
252+
assert "Timestamps are unset in a packet" not in captured.err
253+
if format == "mp3":
254+
assert "Queue input is backward in time" not in captured.err
255+
if format in ("flac", "wav"):
256+
assert "Encoder did not produce proper pts" not in captured.err
257+
if format in ("flac", "mp3"):
258+
assert "Application provided invalid" not in captured.err
259+
260+
if sample_rate != asset.sample_rate:
261+
rtol, atol = 0, 1e-3
262+
elif format == "wav":
261263
rtol, atol = 0, 1e-4
262264
elif format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2:
263265
# Not sure why, this one needs slightly higher tol. With default
@@ -268,7 +270,6 @@ def test_against_cli(
268270
else:
269271
rtol, atol = None, None
270272

271-
rtol, atol = 0, 1e-3
272273
samples_by_us = self.decode(encoded_by_us)
273274
samples_by_ffmpeg = self.decode(encoded_by_ffmpeg)
274275
torch.testing.assert_close(

0 commit comments

Comments
 (0)