Skip to content

Commit 061c60f

Browse files
committed
Address some comments
1 parent bbf2af2 commit 061c60f

File tree

9 files changed

+104
-83
lines changed

9 files changed

+104
-83
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
#include "src/torchcodec/_core/Encoder.h"
22
#include "torch/types.h"
33

4-
extern "C" {
5-
#include <libavcodec/avcodec.h>
6-
#include <libavformat/avformat.h>
7-
}
8-
94
namespace facebook::torchcodec {
105

11-
Encoder::~Encoder() {}
6+
AudioEncoder::~AudioEncoder() {}
127

138
// TODO-ENCODING: disable ffmpeg logs by default
149

15-
Encoder::Encoder(
10+
AudioEncoder::AudioEncoder(
1611
const torch::Tensor wf,
1712
int sampleRate,
1813
std::string_view fileName)
@@ -24,21 +19,21 @@ Encoder::Encoder(
2419
TORCH_CHECK(
2520
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
2621
AVFormatContext* avFormatContext = nullptr;
27-
avformat_alloc_output_context2(
22+
auto status = avformat_alloc_output_context2(
2823
&avFormatContext, nullptr, nullptr, fileName.data());
2924
TORCH_CHECK(
3025
avFormatContext != nullptr,
3126
"Couldn't allocate AVFormatContext. ",
32-
"Check the desired extension?");
27+
"Check the desired extension? ",
28+
getFFMPEGErrorStringFromErrorCode(status));
3329
avFormatContext_.reset(avFormatContext);
3430

3531
// TODO-ENCODING: Should also support encoding into bytes (use
3632
// AVIOBytesContext)
3733
TORCH_CHECK(
3834
!(avFormatContext->oformat->flags & AVFMT_NOFILE),
3935
"AVFMT_NOFILE is set. We only support writing to a file.");
40-
auto status =
41-
avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
36+
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
4237
TORCH_CHECK(
4338
status >= 0,
4439
"avio_open failed: ",
@@ -85,7 +80,10 @@ Encoder::Encoder(
8580
setDefaultChannelLayout(avCodecContext_, numChannels);
8681

8782
status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
88-
TORCH_CHECK(status == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(status));
83+
TORCH_CHECK(
84+
status == AVSUCCESS,
85+
"avcodec_open2 failed: ",
86+
getFFMPEGErrorStringFromErrorCode(status));
8987

9088
TORCH_CHECK(
9189
avCodecContext_->frame_size > 0,
@@ -96,12 +94,18 @@ Encoder::Encoder(
9694
// We're allocating the stream here. Streams are meant to be freed by
9795
// avformat_free_context(avFormatContext), which we call in the
9896
// avFormatContext_'s destructor.
99-
avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
100-
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");
101-
avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get());
97+
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
98+
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
99+
status = avcodec_parameters_from_context(
100+
avStream->codecpar, avCodecContext_.get());
101+
TORCH_CHECK(
102+
status == AVSUCCESS,
103+
"avcodec_parameters_from_context failed: ",
104+
getFFMPEGErrorStringFromErrorCode(status));
105+
streamIndex_ = avStream->index;
102106
}
103107

104-
void Encoder::encode() {
108+
void AudioEncoder::encode() {
105109
UniqueAVFrame avFrame(av_frame_alloc());
106110
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
107111
avFrame->nb_samples = avCodecContext_->frame_size;
@@ -119,12 +123,11 @@ void Encoder::encode() {
119123
AutoAVPacket autoAVPacket;
120124

121125
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
122-
auto numSamples = wf_.sizes()[1]; // per channel
123-
auto numEncodedSamples = 0; // per channel
124-
auto numSamplesPerFrame =
125-
static_cast<long>(avCodecContext_->frame_size); // per channel
126-
auto numBytesPerSample = wf_.element_size();
127-
auto numBytesPerChannel = numSamples * numBytesPerSample;
126+
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
127+
int numEncodedSamples = 0; // per channel
128+
int numSamplesPerFrame = avCodecContext_->frame_size; // per channel
129+
int numBytesPerSample = wf_.element_size();
130+
int numBytesPerChannel = numSamples * numBytesPerSample;
128131

129132
status = avformat_write_header(avFormatContext_.get(), nullptr);
130133
TORCH_CHECK(
@@ -139,12 +142,12 @@ void Encoder::encode() {
139142
"Couldn't make AVFrame writable: ",
140143
getFFMPEGErrorStringFromErrorCode(status));
141144

142-
auto numSamplesToEncode = std::min(
143-
numSamplesPerFrame, static_cast<long>(numSamples - numEncodedSamples));
144-
auto numBytesToEncode = numSamplesToEncode * numBytesPerSample;
145+
int numSamplesToEncode =
146+
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
147+
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
145148

146149
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
147-
memcpy(
150+
std::memcpy(
148151
avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
149152
}
150153
pwf += numBytesToEncode;
@@ -155,14 +158,14 @@ void Encoder::encode() {
155158
// encoded frame would contain more samples than necessary and our results
156159
// wouldn't match the ffmpeg CLI.
157160
avFrame->nb_samples = numSamplesToEncode;
158-
encode_inner_loop(autoAVPacket, avFrame);
161+
encodeInnerLoop(autoAVPacket, avFrame);
159162

160-
avFrame->pts += numSamplesToEncode;
163+
avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
161164
numEncodedSamples += numSamplesToEncode;
162165
}
163166
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
164167

165-
encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush
168+
flushBuffers();
166169

167170
status = av_write_trailer(avFormatContext_.get());
168171
TORCH_CHECK(
@@ -171,7 +174,7 @@ void Encoder::encode() {
171174
getFFMPEGErrorStringFromErrorCode(status));
172175
}
173176

174-
void Encoder::encode_inner_loop(
177+
void AudioEncoder::encodeInnerLoop(
175178
AutoAVPacket& autoAVPacket,
176179
const UniqueAVFrame& avFrame) {
177180
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
@@ -199,10 +202,7 @@ void Encoder::encode_inner_loop(
199202
"Error receiving packet: ",
200203
getFFMPEGErrorStringFromErrorCode(status));
201204

202-
// TODO-ENCODING why are these 2 lines needed??
203-
av_packet_rescale_ts(
204-
packet.get(), avCodecContext_->time_base, avStream_->time_base);
205-
packet->stream_index = avStream_->index;
205+
packet->stream_index = streamIndex_;
206206

207207
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
208208
TORCH_CHECK(
@@ -211,4 +211,9 @@ void Encoder::encode_inner_loop(
211211
getFFMPEGErrorStringFromErrorCode(status));
212212
}
213213
}
214+
215+
void AudioEncoder::flushBuffers() {
216+
AutoAVPacket autoAVPacket;
217+
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
218+
}
214219
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,25 @@
33
#include "src/torchcodec/_core/FFMPEGCommon.h"
44

55
namespace facebook::torchcodec {
6-
class Encoder {
6+
class AudioEncoder {
77
public:
8-
~Encoder();
8+
~AudioEncoder();
99

10-
Encoder(const torch::Tensor wf, int sampleRate, std::string_view fileName);
10+
AudioEncoder(
11+
const torch::Tensor wf,
12+
int sampleRate,
13+
std::string_view fileName);
1114
void encode();
1215

1316
private:
14-
void encode_inner_loop(
17+
void encodeInnerLoop(
1518
AutoAVPacket& autoAVPacket,
1619
const UniqueAVFrame& avFrame);
20+
void flushBuffers();
1721

18-
UniqueAVFormatContextForEncoding avFormatContext_;
22+
UniqueEncodingAVFormatContext avFormatContext_;
1923
UniqueAVCodecContext avCodecContext_;
20-
AVStream* avStream_;
24+
int streamIndex_;
2125

2226
const torch::Tensor wf_;
2327
// The *output* sample rate. We can't really decide for the user what it

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ struct Deleter {
5050
};
5151

5252
// Unique pointers for FFMPEG structures.
53-
using UniqueAVFormatContextForDecoding = std::unique_ptr<
53+
using UniqueDecodingAVFormatContext = std::unique_ptr<
5454
AVFormatContext,
5555
Deleterp<AVFormatContext, void, avformat_close_input>>;
56-
using UniqueAVFormatContextForEncoding = std::unique_ptr<
56+
using UniqueEncodingAVFormatContext = std::unique_ptr<
5757
AVFormatContext,
5858
Deleter<AVFormatContext, void, avformat_free_context>>;
5959
using UniqueAVCodecContext = std::unique_ptr<

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
14431443
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
14441444
for (auto channel = 0; channel < numChannels;
14451445
++channel, outputChannelData += numBytesPerChannel) {
1446-
memcpy(
1446+
std::memcpy(
14471447
outputChannelData,
14481448
avFrame->extended_data[channel],
14491449
numBytesPerChannel);

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ class SingleStreamDecoder {
492492

493493
SeekMode seekMode_;
494494
ContainerMetadata containerMetadata_;
495-
UniqueAVFormatContextForDecoding formatContext_;
495+
UniqueDecodingAVFormatContext formatContext_;
496496
std::map<int, StreamInfo> streamInfos_;
497497
const int NO_ACTIVE_STREAM = -2;
498498
int activeStreamIndex_ = NO_ACTIVE_STREAM;

src/torchcodec/_core/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
_test_frame_pts_equality,
1919
add_audio_stream,
2020
add_video_stream,
21-
create_encoder,
21+
create_audio_encoder,
2222
create_from_bytes,
2323
create_from_file,
2424
create_from_file_like,
2525
create_from_tensor,
26-
encode,
26+
encode_audio,
2727
get_ffmpeg_library_versions,
2828
get_frame_at_index,
2929
get_frame_at_pts,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2828
m.impl_abstract_pystub(
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
31-
m.def("create_encoder(Tensor wf, int sample_rate, str filename) -> Tensor");
32-
m.def("encode(Tensor(a!) encoder) -> ()");
31+
m.def(
32+
"create_audio_encoder(Tensor wf, int sample_rate, str filename) -> Tensor");
33+
m.def("encode_audio(Tensor(a!) encoder) -> ()");
3334
m.def(
3435
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3536
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -384,35 +385,42 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
384385
return makeOpsAudioFramesOutput(result);
385386
}
386387

387-
at::Tensor wrapEncoderPointerToTensor(std::unique_ptr<Encoder> uniqueEncoder) {
388-
Encoder* encoder = uniqueEncoder.release();
388+
at::Tensor wrapAudioEncoderPointerToTensor(
389+
std::unique_ptr<AudioEncoder> uniqueAudioEncoder) {
390+
AudioEncoder* encoder = uniqueAudioEncoder.release();
389391

390392
auto deleter = [encoder](void*) { delete encoder; };
391393
at::Tensor tensor =
392-
at::from_blob(encoder, {sizeof(Encoder)}, deleter, {at::kLong});
393-
auto encoder_ = static_cast<Encoder*>(tensor.mutable_data_ptr());
394-
TORCH_CHECK_EQ(encoder_, encoder) << "Encoder=" << encoder_;
394+
at::from_blob(encoder, {sizeof(AudioEncoder*)}, deleter, {at::kLong});
395+
auto encoder_ = static_cast<AudioEncoder*>(tensor.mutable_data_ptr());
396+
TORCH_CHECK_EQ(encoder_, encoder) << "AudioEncoder=" << encoder_;
395397
return tensor;
396398
}
397399

398-
Encoder* unwrapTensorToGetEncoder(at::Tensor& tensor) {
400+
AudioEncoder* unwrapTensorToGetAudioEncoder(at::Tensor& tensor) {
399401
TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
400402
void* buffer = tensor.mutable_data_ptr();
401-
Encoder* encoder = static_cast<Encoder*>(buffer);
403+
AudioEncoder* encoder = static_cast<AudioEncoder*>(buffer);
402404
return encoder;
403405
}
404406

405-
at::Tensor create_encoder(
407+
at::Tensor create_audio_encoder(
406408
const at::Tensor wf,
407409
int64_t sample_rate,
408410
std::string_view file_name) {
409-
std::unique_ptr<Encoder> uniqueEncoder =
410-
std::make_unique<Encoder>(wf, static_cast<int>(sample_rate), file_name);
411-
return wrapEncoderPointerToTensor(std::move(uniqueEncoder));
412-
}
413-
414-
void encode(at::Tensor& encoder) {
415-
auto encoder_ = unwrapTensorToGetEncoder(encoder);
411+
TORCH_CHECK(
412+
sample_rate <= std::numeric_limits<int>::max(),
413+
"sample_rate=",
414+
sample_rate,
415+
" is too large to be cast to an int.");
416+
std::unique_ptr<AudioEncoder> uniqueAudioEncoder =
417+
std::make_unique<AudioEncoder>(
418+
wf, static_cast<int>(sample_rate), file_name);
419+
return wrapAudioEncoderPointerToTensor(std::move(uniqueAudioEncoder));
420+
}
421+
422+
void encode_audio(at::Tensor& encoder) {
423+
auto encoder_ = unwrapTensorToGetAudioEncoder(encoder);
416424
encoder_->encode();
417425
}
418426

@@ -650,15 +658,15 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
650658

651659
TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
652660
m.impl("create_from_file", &create_from_file);
653-
m.impl("create_encoder", &create_encoder);
661+
m.impl("create_audio_encoder", &create_audio_encoder);
654662
m.impl("create_from_tensor", &create_from_tensor);
655663
m.impl("_convert_to_tensor", &_convert_to_tensor);
656664
m.impl(
657665
"_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
658666
}
659667

660668
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
661-
m.impl("encode", &encode);
669+
m.impl("encode_audio", &encode_audio);
662670
m.impl("seek_to_pts", &seek_to_pts);
663671
m.impl("add_video_stream", &add_video_stream);
664672
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,12 @@ def load_torchcodec_shared_libraries():
9191
create_from_file = torch._dynamo.disallow_in_graph(
9292
torch.ops.torchcodec_ns.create_from_file.default
9393
)
94-
create_encoder = torch._dynamo.disallow_in_graph(
95-
torch.ops.torchcodec_ns.create_encoder.default
94+
create_audio_encoder = torch._dynamo.disallow_in_graph(
95+
torch.ops.torchcodec_ns.create_audio_encoder.default
96+
)
97+
encode_audio = torch._dynamo.disallow_in_graph(
98+
torch.ops.torchcodec_ns.encode_audio.default
9699
)
97-
encode = torch._dynamo.disallow_in_graph(torch.ops.torchcodec_ns.encode.default)
98100
create_from_tensor = torch._dynamo.disallow_in_graph(
99101
torch.ops.torchcodec_ns.create_from_tensor.default
100102
)
@@ -159,15 +161,15 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
159161
return torch.empty([], dtype=torch.long)
160162

161163

162-
@register_fake("torchcodec_ns::create_encoder")
163-
def create_encoder_abstract(
164+
@register_fake("torchcodec_ns::create_audio_encoder")
165+
def create_audio_encoder_abstract(
164166
wf: torch.Tensor, sample_rate: int, filename: str
165167
) -> torch.Tensor:
166168
return torch.empty([], dtype=torch.long)
167169

168170

169-
@register_fake("torchcodec_ns::encode")
170-
def encode_abstract(encoder: torch.Tensor) -> torch.Tensor:
171+
@register_fake("torchcodec_ns::encode_audio")
172+
def encode_audio_abstract(encoder: torch.Tensor) -> torch.Tensor:
171173
return torch.empty([], dtype=torch.long)
172174

173175

0 commit comments

Comments
 (0)