Skip to content

Commit eee8889

Browse files
committed
Merge branch 'main' of https://github.com/meta-pytorch/torchcodec into encode_gpu
2 parents 926b7ea + cac99ae commit eee8889

File tree

6 files changed

+106
-58
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -674,7 +674,7 @@ VideoEncoder::~VideoEncoder() {
674674

675675
VideoEncoder::VideoEncoder(
676676
const torch::Tensor& frames,
677-
int frameRate,
677+
double frameRate,
678678
std::string_view fileName,
679679
const VideoStreamOptions& videoStreamOptions)
680680
: frames_(validateFrames(frames, videoStreamOptions.device)),
@@ -707,7 +707,7 @@ VideoEncoder::VideoEncoder(
707707

708708
VideoEncoder::VideoEncoder(
709709
const torch::Tensor& frames,
710-
int frameRate,
710+
double frameRate,
711711
std::string_view formatName,
712712
std::unique_ptr<AVIOContextHolder> avioContextHolder,
713713
const VideoStreamOptions& videoStreamOptions)
@@ -812,9 +812,9 @@ void VideoEncoder::initializeEncoder(
812812
avCodecContext_->width = outWidth_;
813813
avCodecContext_->height = outHeight_;
814814
avCodecContext_->pix_fmt = outPixelFormat_;
815-
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
816-
avCodecContext_->time_base = {1, inFrameRate_};
817-
avCodecContext_->framerate = {inFrameRate_, 1};
815+
// TODO-VideoEncoder: Add and utilize output frame_rate option
816+
avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX);
817+
avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate);
818818

819819
// Set flag for containers that require extradata to be in the codec context
820820
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
@@ -866,6 +866,10 @@ void VideoEncoder::initializeEncoder(
866866

867867
// Set the stream time base to encode correct frame timestamps
868868
avStream_->time_base = avCodecContext_->time_base;
869+
// Set the stream frame rate to store correct frame durations for some
870+
// containers (webm, mkv)
871+
avStream_->r_frame_rate = avCodecContext_->framerate;
872+
869873
status = avcodec_parameters_from_context(
870874
avStream_->codecpar, avCodecContext_.get());
871875
TORCH_CHECK(

src/torchcodec/_core/Encoder.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -144,13 +144,13 @@ class VideoEncoder {
144144

145145
VideoEncoder(
146146
const torch::Tensor& frames,
147-
int frameRate,
147+
double frameRate,
148148
std::string_view fileName,
149149
const VideoStreamOptions& videoStreamOptions);
150150

151151
VideoEncoder(
152152
const torch::Tensor& frames,
153-
int frameRate,
153+
double frameRate,
154154
std::string_view formatName,
155155
std::unique_ptr<AVIOContextHolder> avioContextHolder,
156156
const VideoStreamOptions& videoStreamOptions);
@@ -170,7 +170,7 @@ class VideoEncoder {
170170
UniqueSwsContext swsContext_;
171171

172172
const torch::Tensor frames_;
173-
int inFrameRate_;
173+
double inFrameRate_;
174174

175175
int inWidth_ = -1;
176176
int inHeight_ = -1;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
40+
"encode_video_to_file(Tensor frames, float frame_rate, str filename, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, float frame_rate, str format, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\",str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str device=\"cpu\",str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -638,7 +638,7 @@ void _encode_audio_to_file_like(
638638

639639
void encode_video_to_file(
640640
const at::Tensor& frames,
641-
int64_t frame_rate,
641+
double frame_rate,
642642
std::string_view file_name,
643643
std::string_view device = "cpu",
644644
std::optional<std::string_view> codec = std::nullopt,
@@ -658,17 +658,12 @@ void encode_video_to_file(
658658
unflattenExtraOptions(extra_options.value());
659659
}
660660

661-
VideoEncoder(
662-
frames,
663-
validateInt64ToInt(frame_rate, "frame_rate"),
664-
file_name,
665-
videoStreamOptions)
666-
.encode();
661+
VideoEncoder(frames, frame_rate, file_name, videoStreamOptions).encode();
667662
}
668663

669664
at::Tensor encode_video_to_tensor(
670665
const at::Tensor& frames,
671-
int64_t frame_rate,
666+
double frame_rate,
672667
std::string_view format,
673668
std::string_view device = "cpu",
674669
std::optional<std::string_view> codec = std::nullopt,
@@ -691,7 +686,7 @@ at::Tensor encode_video_to_tensor(
691686

692687
return VideoEncoder(
693688
frames,
694-
validateInt64ToInt(frame_rate, "frame_rate"),
689+
frame_rate,
695690
format,
696691
std::move(avioContextHolder),
697692
videoStreamOptions)
@@ -700,7 +695,7 @@ at::Tensor encode_video_to_tensor(
700695

701696
void _encode_video_to_file_like(
702697
const at::Tensor& frames,
703-
int64_t frame_rate,
698+
double frame_rate,
704699
std::string_view format,
705700
int64_t file_like_context,
706701
std::string_view device = "cpu",
@@ -729,7 +724,7 @@ void _encode_video_to_file_like(
729724

730725
VideoEncoder encoder(
731726
frames,
732-
validateInt64ToInt(frame_rate, "frame_rate"),
727+
frame_rate,
733728
format,
734729
std::move(avioContextHolder),
735730
videoStreamOptions);

src/torchcodec/_core/ops.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def encode_audio_to_file_like(
210210

211211
def encode_video_to_file_like(
212212
frames: torch.Tensor,
213-
frame_rate: int,
213+
frame_rate: float,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216216
device: Optional[str] = "cpu",
@@ -332,7 +332,7 @@ def _encode_audio_to_file_like_abstract(
332332
@register_fake("torchcodec_ns::encode_video_to_file")
333333
def encode_video_to_file_abstract(
334334
frames: torch.Tensor,
335-
frame_rate: int,
335+
frame_rate: float,
336336
filename: str,
337337
device: str = "cpu",
338338
codec: Optional[str] = None,
@@ -347,7 +347,7 @@ def encode_video_to_file_abstract(
347347
@register_fake("torchcodec_ns::encode_video_to_tensor")
348348
def encode_video_to_tensor_abstract(
349349
frames: torch.Tensor,
350-
frame_rate: int,
350+
frame_rate: float,
351351
format: str,
352352
device: str = "cpu",
353353
codec: Optional[str] = None,
@@ -362,7 +362,7 @@ def encode_video_to_tensor_abstract(
362362
@register_fake("torchcodec_ns::_encode_video_to_file_like")
363363
def _encode_video_to_file_like_abstract(
364364
frames: torch.Tensor,
365-
frame_rate: int,
365+
frame_rate: float,
366366
format: str,
367367
file_like_context: int,
368368
device: str = "cpu",

src/torchcodec/encoders/_video_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ class VideoEncoder:
1515
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
1616
C is 3 channels (RGB), H is height, and W is width.
1717
Values must be uint8 in the range ``[0, 255]``.
18-
frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
18+
frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
1919
device (str or torch.device, optional): The device to use for encoding. Default: "cpu".
2020
If you pass a CUDA device, frames will be encoded on GPU.
2121
Note: The "beta" CUDA backend is not supported for encoding.
@@ -25,7 +25,7 @@ def __init__(
2525
self,
2626
frames: Tensor,
2727
*,
28-
frame_rate: int,
28+
frame_rate: float,
2929
device: Optional[Union[str, torch_device]] = "cpu",
3030
):
3131
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")

0 commit comments

Comments
 (0)