1010#include < string>
1111#include " c10/core/SymIntArrayRef.h"
1212#include " c10/util/Exception.h"
13+ #include " src/torchcodec/_core/AVIOFileLikeContext.h"
1314#include " src/torchcodec/_core/AVIOTensorContext.h"
1415#include " src/torchcodec/_core/Encoder.h"
1516#include " src/torchcodec/_core/SingleStreamDecoder.h"
@@ -35,9 +36,12 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3536 " encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()" );
3637 m.def (
3738 " encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor" );
39+ m.def (
40+ " _encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()" );
3841 m.def (
3942 " create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor" );
40- m.def (" _convert_to_tensor(int decoder_ptr) -> Tensor" );
43+ m.def (
44+ " _create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor" );
4145 m.def (
4246 " _add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()" );
4347 m.def (
@@ -167,6 +171,18 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
167171 return ss.str ();
168172}
169173
174+ SingleStreamDecoder::SeekMode seekModeFromString (std::string_view seekMode) {
175+ if (seekMode == " exact" ) {
176+ return SingleStreamDecoder::SeekMode::exact;
177+ } else if (seekMode == " approximate" ) {
178+ return SingleStreamDecoder::SeekMode::approximate;
179+ } else if (seekMode == " custom_frame_mappings" ) {
180+ return SingleStreamDecoder::SeekMode::custom_frame_mappings;
181+ } else {
182+ TORCH_CHECK (false , " Invalid seek mode: " + std::string (seekMode));
183+ }
184+ }
185+
170186} // namespace
171187
172188// ==============================
@@ -205,16 +221,32 @@ at::Tensor create_from_tensor(
205221 realSeek = seekModeFromString (seek_mode.value ());
206222 }
207223
208- auto contextHolder = std::make_unique<AVIOFromTensorContext>(video_tensor);
224+ auto avioContextHolder =
225+ std::make_unique<AVIOFromTensorContext>(video_tensor);
209226
210227 std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
211- std::make_unique<SingleStreamDecoder>(std::move (contextHolder), realSeek);
228+ std::make_unique<SingleStreamDecoder>(
229+ std::move (avioContextHolder), realSeek);
212230 return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
213231}
214232
215- at::Tensor _convert_to_tensor (int64_t decoder_ptr) {
216- auto decoder = reinterpret_cast <SingleStreamDecoder*>(decoder_ptr);
217- std::unique_ptr<SingleStreamDecoder> uniqueDecoder (decoder);
233+ at::Tensor _create_from_file_like (
234+ int64_t file_like_context,
235+ std::optional<std::string_view> seek_mode) {
236+ auto fileLikeContext =
237+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
238+ TORCH_CHECK (
239+ fileLikeContext != nullptr , " file_like_context must be a valid pointer" );
240+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder (fileLikeContext);
241+
242+ SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
243+ if (seek_mode.has_value ()) {
244+ realSeek = seekModeFromString (seek_mode.value ());
245+ }
246+
247+ std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
248+ std::make_unique<SingleStreamDecoder>(
249+ std::move (avioContextHolder), realSeek);
218250 return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
219251}
220252
@@ -456,6 +488,36 @@ at::Tensor encode_audio_to_tensor(
456488 .encodeToTensor ();
457489}
458490
491+ void _encode_audio_to_file_like (
492+ const at::Tensor& samples,
493+ int64_t sample_rate,
494+ std::string_view format,
495+ int64_t file_like_context,
496+ std::optional<int64_t > bit_rate = std::nullopt ,
497+ std::optional<int64_t > num_channels = std::nullopt ,
498+ std::optional<int64_t > desired_sample_rate = std::nullopt ) {
499+ auto fileLikeContext =
500+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
501+ TORCH_CHECK (
502+ fileLikeContext != nullptr , " file_like_context must be a valid pointer" );
503+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder (fileLikeContext);
504+
505+ AudioStreamOptions audioStreamOptions;
506+ audioStreamOptions.bitRate = validateOptionalInt64ToInt (bit_rate, " bit_rate" );
507+ audioStreamOptions.numChannels =
508+ validateOptionalInt64ToInt (num_channels, " num_channels" );
509+ audioStreamOptions.sampleRate =
510+ validateOptionalInt64ToInt (desired_sample_rate, " desired_sample_rate" );
511+
512+ AudioEncoder encoder (
513+ samples,
514+ validateInt64ToInt (sample_rate, " sample_rate" ),
515+ format,
516+ std::move (avioContextHolder),
517+ audioStreamOptions);
518+ encoder.encode ();
519+ }
520+
459521// For testing only. We need to implement this operation as a core library
460522// function because what we're testing is round-tripping pts values as
461523// double-precision floating point numbers from C++ to Python and back to C++.
@@ -709,7 +771,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
// Registers, under the BackendSelect dispatch key of the torchcodec_ns
// library, the implementations of the create_from_* decoder-creation ops and
// of the FFmpeg library-version query. Each m.impl call binds an op name to
// the C++ function of the same name.
// NOTE(review): each op name registered here needs a matching schema declared
// with m.def in the TORCH_LIBRARY block — confirm the two lists stay in sync.
709771TORCH_LIBRARY_IMPL (torchcodec_ns, BackendSelect, m) {
710772 m.impl (" create_from_file" , &create_from_file);
711773 m.impl (" create_from_tensor" , &create_from_tensor);
712- m.impl (" _convert_to_tensor " , &_convert_to_tensor );
774+ m.impl (" _create_from_file_like " , &_create_from_file_like );
713775 m.impl (
714776 " _get_json_ffmpeg_library_versions" , &_get_json_ffmpeg_library_versions);
715777}
@@ -718,6 +780,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
718780 m.impl (" encode_audio_to_file" , &encode_audio_to_file);
719781 m.impl (" encode_video_to_file" , &encode_video_to_file);
720782 m.impl (" encode_audio_to_tensor" , &encode_audio_to_tensor);
783+ m.impl (" _encode_audio_to_file_like" , &_encode_audio_to_file_like);
721784 m.impl (" seek_to_pts" , &seek_to_pts);
722785 m.impl (" add_video_stream" , &add_video_stream);
723786 m.impl (" _add_video_stream" , &_add_video_stream);
0 commit comments