Skip to content

Commit 3399b34

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into encoding_sample_rate_lezzzgo
2 parents 8fdb6ed + 3056f40 commit 3399b34

File tree

6 files changed

+49
-43
lines changed

6 files changed

+49
-43
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,19 @@ namespace facebook::torchcodec {
88

99
namespace {
1010

11-
torch::Tensor validateWf(torch::Tensor wf) {
11+
torch::Tensor validateSamples(torch::Tensor samples) {
1212
TORCH_CHECK(
13-
wf.dtype() == torch::kFloat32,
14-
"waveform must have float32 dtype, got ",
15-
wf.dtype());
16-
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
13+
samples.dtype() == torch::kFloat32,
14+
"samples must have float32 dtype, got ",
15+
samples.dtype());
16+
TORCH_CHECK(
17+
samples.dim() == 2,
18+
"samples must have 2 dimensions, got ",
19+
samples.dim());
1720

1821
// We enforce this, but if we get user reports we should investigate whether
1922
// that's actually needed.
20-
int numChannels = static_cast<int>(wf.sizes()[0]);
23+
int numChannels = static_cast<int>(samples.sizes()[0]);
2124
TORCH_CHECK(
2225
numChannels <= AV_NUM_DATA_POINTERS,
2326
"Trying to encode ",
@@ -26,7 +29,7 @@ torch::Tensor validateWf(torch::Tensor wf) {
2629
AV_NUM_DATA_POINTERS,
2730
" channels per frame.");
2831

29-
return wf.contiguous();
32+
return samples.contiguous();
3033
}
3134

3235
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
@@ -71,7 +74,7 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
7174

7275
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
7376
// Find a sample format that the encoder supports. We prefer using FLT[P],
74-
// since this is the format of the input waveform. If FLTP isn't supported
77+
// since this is the format of the input samples. If FLTP isn't supported
7578
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
7679
// into the format with the highest resolution.
7780
if (avCodec.sample_fmts == nullptr) {
@@ -115,11 +118,11 @@ UniqueAVFrame allocateAVFrame(int numSamples, int sampleRate, int numChannels) {
115118
AudioEncoder::~AudioEncoder() {}
116119

117120
AudioEncoder::AudioEncoder(
118-
const torch::Tensor wf,
121+
const torch::Tensor samples,
119122
int sampleRate,
120123
std::string_view fileName,
121124
const AudioStreamOptions& audioStreamOptions)
122-
: wf_(validateWf(wf)), sampleRateInput_(sampleRate) {
125+
: samples_(validateSamples(samples)), sampleRateInput_(sampleRate) {
123126
setFFmpegLogLevel();
124127
AVFormatContext* avFormatContext = nullptr;
125128
int status = avformat_alloc_output_context2(
@@ -146,12 +149,12 @@ AudioEncoder::AudioEncoder(
146149
}
147150

148151
AudioEncoder::AudioEncoder(
149-
const torch::Tensor wf,
152+
const torch::Tensor samples,
150153
int sampleRate,
151154
std::string_view formatName,
152155
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
153156
const AudioStreamOptions& audioStreamOptions)
154-
: wf_(validateWf(wf)),
157+
: samples_(validateSamples(samples)),
155158
sampleRateInput_(sampleRate),
156159
avioContextHolder_(std::move(avioContextHolder)) {
157160
setFFmpegLogLevel();
@@ -194,8 +197,8 @@ void AudioEncoder::initializeEncoder(
194197
// well when "-b:a" isn't specified.
195198
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
196199

197-
outNumChannels_ =
198-
static_cast<int>(audioStreamOptions.numChannels.value_or(wf_.sizes()[0]));
200+
outNumChannels_ = static_cast<int>(
201+
audioStreamOptions.numChannels.value_or(samples_.sizes()[0]));
199202
validateNumChannels(*avCodec, outNumChannels_);
200203
// The avCodecContext layout defines the layout of the encoded output, it's
201204
// not related to the input sampes.
@@ -205,9 +208,9 @@ void AudioEncoder::initializeEncoder(
205208
validateSampleRate(*avCodec, outSampleRate_);
206209
avCodecContext_->sample_rate = outSampleRate_;
207210

208-
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
209-
// may need to convert the wf into a supported output sample format, which is
210-
// what the `.sample_fmt` defines.
211+
// Input samples are expected to be FLTP. Not all encoders support FLTP, so we
212+
// may need to convert the samples into a supported output sample format,
213+
// which is what the `.sample_fmt` defines.
211214
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
212215

213216
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
@@ -265,15 +268,15 @@ void AudioEncoder::encode() {
265268
UniqueAVFrame avFrame = allocateAVFrame(
266269
numSamplesAllocatedPerFrame,
267270
sampleRateInput_,
268-
static_cast<int>(wf_.sizes()[0]));
271+
static_cast<int>(samples_.sizes()[0]));
269272
avFrame->pts = 0;
270273

271274
AutoAVPacket autoAVPacket;
272275

273-
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
274-
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
276+
uint8_t* psamples = static_cast<uint8_t*>(samples_.data_ptr());
277+
int numSamples = static_cast<int>(samples_.sizes()[1]); // per channel
275278
int numEncodedSamples = 0; // per channel
276-
int numBytesPerSample = static_cast<int>(wf_.element_size());
279+
int numBytesPerSample = static_cast<int>(samples_.element_size());
277280
int numBytesPerChannel = numSamples * numBytesPerSample;
278281

279282
auto status = avformat_write_header(avFormatContext_.get(), nullptr);
@@ -293,11 +296,13 @@ void AudioEncoder::encode() {
293296
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
294297
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
295298

296-
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
299+
for (int ch = 0; ch < samples_.sizes()[0]; ch++) {
297300
std::memcpy(
298-
avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
301+
avFrame->data[ch],
302+
psamples + ch * numBytesPerChannel,
303+
numBytesToEncode);
299304
}
300-
pwf += numBytesToEncode;
305+
psamples += numBytesToEncode;
301306

302307
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
303308
// that the frame buffers are allocated to a big enough size. Here, we reset

src/torchcodec/_core/Encoder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@ class AudioEncoder {
1515
// Passing 44_100 could result in output being 44000 if only 44000 is
1616
// supported.
1717
AudioEncoder(
18-
const torch::Tensor wf,
18+
const torch::Tensor samples,
1919
// TODO-ENCODING: update this comment when we support an output sample
2020
// rate. This will become the input sample rate.
2121
// The *output* sample rate. We can't really decide for the user what it
22-
// should be. Particularly, the sample rate of the input waveform should
22+
// should be. Particularly, the sample rate of the input samples should
2323
// match this, and that's up to the user. If sample rates don't match,
2424
// encoding will still work but audio will be distorted.
2525
int sampleRate,
2626
std::string_view fileName,
2727
const AudioStreamOptions& audioStreamOptions);
2828
AudioEncoder(
29-
const torch::Tensor wf,
29+
const torch::Tensor samples,
3030
int sampleRate,
3131
std::string_view formatName,
3232
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
@@ -52,7 +52,7 @@ class AudioEncoder {
5252
int outNumChannels_ = -1;
5353
int outSampleRate_ = -1;
5454

55-
const torch::Tensor wf_;
55+
const torch::Tensor samples_;
5656
int sampleRateInput_ = -1;
5757

5858
UniqueAVAudioFifo avAudioFifo_;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3131
m.def(
32-
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
32+
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3333
m.def(
34-
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
34+
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3535
m.def(
3636
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3737
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -388,7 +388,7 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
388388
}
389389

390390
void encode_audio_to_file(
391-
const at::Tensor wf,
391+
const at::Tensor samples,
392392
int64_t sample_rate,
393393
std::string_view file_name,
394394
std::optional<int64_t> bit_rate = std::nullopt,
@@ -401,12 +401,12 @@ void encode_audio_to_file(
401401
audioStreamOptions.numChannels = num_channels;
402402
audioStreamOptions.sampleRate = desired_sample_rate;
403403
AudioEncoder(
404-
wf, validateSampleRate(sample_rate), file_name, audioStreamOptions)
404+
samples, validateSampleRate(sample_rate), file_name, audioStreamOptions)
405405
.encode();
406406
}
407407

408408
at::Tensor encode_audio_to_tensor(
409-
const at::Tensor wf,
409+
const at::Tensor samples,
410410
int64_t sample_rate,
411411
std::string_view format,
412412
std::optional<int64_t> bit_rate = std::nullopt,
@@ -420,7 +420,7 @@ at::Tensor encode_audio_to_tensor(
420420
audioStreamOptions.numChannels = num_channels;
421421
audioStreamOptions.sampleRate = desired_sample_rate;
422422
return AudioEncoder(
423-
wf,
423+
samples,
424424
validateSampleRate(sample_rate),
425425
format,
426426
std::move(avioContextHolder),

src/torchcodec/_core/ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,9 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
161161
return torch.empty([], dtype=torch.long)
162162

163163

164-
# TODO-ENCODING: rename wf to samples
165164
@register_fake("torchcodec_ns::encode_audio_to_file")
166165
def encode_audio_to_file_abstract(
167-
wf: torch.Tensor,
166+
samples: torch.Tensor,
168167
sample_rate: int,
169168
filename: str,
170169
bit_rate: Optional[int] = None,
@@ -176,7 +175,7 @@ def encode_audio_to_file_abstract(
176175

177176
@register_fake("torchcodec_ns::encode_audio_to_tensor")
178177
def encode_audio_to_tensor_abstract(
179-
wf: torch.Tensor,
178+
samples: torch.Tensor,
180179
sample_rate: int,
181180
format: str,
182181
bit_rate: Optional[int] = None,

src/torchcodec/encoders/_audio_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def to_file(
3535
sample_rate: Optional[int] = None,
3636
) -> None:
3737
_core.encode_audio_to_file(
38-
wf=self._samples,
38+
samples=self._samples,
3939
sample_rate=self._sample_rate,
4040
filename=dest,
4141
bit_rate=bit_rate,
@@ -52,7 +52,7 @@ def to_tensor(
5252
sample_rate: Optional[int] = None,
5353
) -> Tensor:
5454
return _core.encode_audio_to_tensor(
55-
wf=self._samples,
55+
samples=self._samples,
5656
sample_rate=self._sample_rate,
5757
format=format,
5858
bit_rate=bit_rate,

test/test_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,22 +1101,24 @@ def test_bad_input(self, tmp_path):
11011101

11021102
with pytest.raises(RuntimeError, match="must have float32 dtype, got int"):
11031103
encode_audio_to_file(
1104-
wf=torch.arange(10, dtype=torch.int),
1104+
samples=torch.arange(10, dtype=torch.int),
11051105
sample_rate=10,
11061106
filename=valid_output_file,
11071107
)
11081108
with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"):
11091109
encode_audio_to_file(
1110-
wf=torch.rand(3), sample_rate=10, filename=valid_output_file
1110+
samples=torch.rand(3), sample_rate=10, filename=valid_output_file
11111111
)
11121112

11131113
with pytest.raises(RuntimeError, match="No such file or directory"):
11141114
encode_audio_to_file(
1115-
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
1115+
samples=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
11161116
)
11171117
with pytest.raises(RuntimeError, match="check the desired extension"):
11181118
encode_audio_to_file(
1119-
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
1119+
samples=torch.rand(2, 10),
1120+
sample_rate=10,
1121+
filename="./file.bad_extension",
11201122
)
11211123

11221124

0 commit comments

Comments
 (0)