Skip to content

Commit aad9c7d

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into encoding_num_channels
2 parents 52d624b + f1e5e91 commit aad9c7d

File tree

6 files changed

+106
-4
lines changed

6 files changed

+106
-4
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,18 @@ AudioEncoder::AudioEncoder(
126126
TORCH_CHECK(
127127
avFormatContext != nullptr,
128128
"Couldn't allocate AVFormatContext. ",
129-
"Check the desired extension? ",
129+
"The destination file is ",
130+
fileName,
131+
", check the desired extension? ",
130132
getFFMPEGErrorStringFromErrorCode(status));
131133
avFormatContext_.reset(avFormatContext);
132134

133135
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
134136
TORCH_CHECK(
135137
status >= 0,
136-
"avio_open failed: ",
138+
"avio_open failed. The destination file is ",
139+
fileName,
140+
", make sure it's a valid path? ",
137141
getFFMPEGErrorStringFromErrorCode(status));
138142

139143
initializeEncoder(sampleRate, bitRate, numChannels);
@@ -155,7 +159,9 @@ AudioEncoder::AudioEncoder(
155159
TORCH_CHECK(
156160
avFormatContext != nullptr,
157161
"Couldn't allocate AVFormatContext. ",
158-
"Check the desired extension? ",
162+
"Check the desired format? Got format=",
163+
formatName,
164+
". ",
159165
getFFMPEGErrorStringFromErrorCode(status));
160166
avFormatContext_.reset(avFormatContext);
161167

src/torchcodec/_core/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
161161
return torch.empty([], dtype=torch.long)
162162

163163

164+
# TODO-ENCODING: rename wf to samples
164165
@register_fake("torchcodec_ns::encode_audio_to_file")
165166
def encode_audio_to_file_abstract(
166167
wf: torch.Tensor,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._audio_encoder import AudioEncoder # noqa
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from pathlib import Path
2+
from typing import Optional, Union
3+
4+
import torch
5+
from torch import Tensor
6+
7+
from torchcodec import _core
8+
9+
10+
class AudioEncoder:
11+
def __init__(self, samples: Tensor, *, sample_rate: int):
12+
# Some of these checks are also done in C++: it's OK, they're cheap, and
13+
# doing them here allows to surface them when the AudioEncoder is
14+
# instantiated, rather than later when the encoding methods are called.
15+
if not isinstance(samples, Tensor):
16+
raise ValueError(
17+
f"Expected samples to be a Tensor, got {type(samples) = }."
18+
)
19+
if samples.ndim != 2:
20+
raise ValueError(f"Expected 2D samples, got {samples.shape = }.")
21+
if samples.dtype != torch.float32:
22+
raise ValueError(f"Expected float32 samples, got {samples.dtype = }.")
23+
if sample_rate <= 0:
24+
raise ValueError(f"{sample_rate = } must be > 0.")
25+
26+
self._samples = samples
27+
self._sample_rate = sample_rate
28+
29+
def to_file(
30+
self,
31+
dest: Union[str, Path],
32+
*,
33+
bit_rate: Optional[int] = None,
34+
) -> None:
35+
_core.encode_audio_to_file(
36+
wf=self._samples,
37+
sample_rate=self._sample_rate,
38+
filename=dest,
39+
bit_rate=bit_rate,
40+
)
41+
42+
def to_tensor(
43+
self,
44+
format: str,
45+
*,
46+
bit_rate: Optional[int] = None,
47+
) -> Tensor:
48+
return _core.encode_audio_to_tensor(
49+
wf=self._samples,
50+
sample_rate=self._sample_rate,
51+
format=format,
52+
bit_rate=bit_rate,
53+
)

test/test_encoders.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import re
2+
3+
import pytest
4+
import torch
5+
6+
from torchcodec.encoders import AudioEncoder
7+
8+
9+
class TestAudioEncoder:
10+
11+
def test_bad_input(self):
12+
with pytest.raises(ValueError, match="Expected samples to be a Tensor"):
13+
AudioEncoder(samples=123, sample_rate=32_000)
14+
with pytest.raises(ValueError, match="Expected 2D samples"):
15+
AudioEncoder(samples=torch.rand(10), sample_rate=32_000)
16+
with pytest.raises(ValueError, match="Expected float32 samples"):
17+
AudioEncoder(
18+
samples=torch.rand(10, 10, dtype=torch.float64), sample_rate=32_000
19+
)
20+
with pytest.raises(ValueError, match="sample_rate = 0 must be > 0"):
21+
AudioEncoder(samples=torch.rand(10, 10), sample_rate=0)
22+
23+
encoder = AudioEncoder(samples=torch.rand(2, 100), sample_rate=32_000)
24+
25+
bad_path = "/bad/path.mp3"
26+
with pytest.raises(
27+
RuntimeError,
28+
match=f"avio_open failed. The destination file is {bad_path}, make sure it's a valid path",
29+
):
30+
encoder.to_file(dest=bad_path)
31+
32+
bad_extension = "output.bad_extension"
33+
with pytest.raises(RuntimeError, match="check the desired extension"):
34+
encoder.to_file(dest=bad_extension)
35+
36+
bad_format = "bad_format"
37+
with pytest.raises(
38+
RuntimeError,
39+
match=re.escape(f"Check the desired format? Got format={bad_format}"),
40+
):
41+
encoder.to_tensor(format=bad_format)

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,7 @@ def test_bad_input(self, tmp_path):
11341134
encode_audio_to_file(
11351135
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
11361136
)
1137-
with pytest.raises(RuntimeError, match="Check the desired extension"):
1137+
with pytest.raises(RuntimeError, match="check the desired extension"):
11381138
encode_audio_to_file(
11391139
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
11401140
)

0 commit comments

Comments
 (0)