1010#include < string>
1111#include " c10/core/SymIntArrayRef.h"
1212#include " c10/util/Exception.h"
13+ #include " src/torchcodec/_core/AVIOFileLikeContext.h"
1314#include " src/torchcodec/_core/AVIOTensorContext.h"
1415#include " src/torchcodec/_core/Encoder.h"
1516#include " src/torchcodec/_core/SingleStreamDecoder.h"
@@ -33,8 +34,12 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3334 " encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()" );
3435 m.def (
3536 " encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor" );
37+ m.def (
38+ " _encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()" );
3639 m.def (
3740 " create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor" );
41+ m.def (
42+ " _create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor" );
3843 m.def (" _convert_to_tensor(int decoder_ptr) -> Tensor" );
3944 m.def (
4045 " _add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()" );
@@ -210,6 +215,24 @@ at::Tensor create_from_tensor(
210215 return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
211216}
212217
218+ at::Tensor _create_from_file_like (
219+ int64_t file_like_context,
220+ std::optional<std::string_view> seek_mode) {
221+ auto fileLikeContext =
222+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
223+ TORCH_CHECK (fileLikeContext != nullptr , " file_like must be a valid pointer" );
224+ std::unique_ptr<AVIOFileLikeContext> contextHolder (fileLikeContext);
225+
226+ SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
227+ if (seek_mode.has_value ()) {
228+ realSeek = seekModeFromString (seek_mode.value ());
229+ }
230+
231+ std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
232+ std::make_unique<SingleStreamDecoder>(std::move (contextHolder), realSeek);
233+ return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
234+ }
235+
213236at::Tensor _convert_to_tensor (int64_t decoder_ptr) {
214237 auto decoder = reinterpret_cast <SingleStreamDecoder*>(decoder_ptr);
215238 std::unique_ptr<SingleStreamDecoder> uniqueDecoder (decoder);
@@ -441,6 +464,36 @@ at::Tensor encode_audio_to_tensor(
441464 .encodeToTensor ();
442465}
443466
467+ void _encode_audio_to_file_like (
468+ const at::Tensor& samples,
469+ int64_t sample_rate,
470+ std::string_view format,
471+ int64_t file_like_context,
472+ std::optional<int64_t > bit_rate = std::nullopt ,
473+ std::optional<int64_t > num_channels = std::nullopt ,
474+ std::optional<int64_t > desired_sample_rate = std::nullopt ) {
475+ auto fileLikeContext =
476+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
477+ TORCH_CHECK (
478+ fileLikeContext != nullptr , " file_like_context must be a valid pointer" );
479+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder (fileLikeContext);
480+
481+ AudioStreamOptions audioStreamOptions;
482+ audioStreamOptions.bitRate = validateOptionalInt64ToInt (bit_rate, " bit_rate" );
483+ audioStreamOptions.numChannels =
484+ validateOptionalInt64ToInt (num_channels, " num_channels" );
485+ audioStreamOptions.sampleRate =
486+ validateOptionalInt64ToInt (desired_sample_rate, " desired_sample_rate" );
487+
488+ AudioEncoder encoder (
489+ samples,
490+ validateInt64ToInt (sample_rate, " sample_rate" ),
491+ format,
492+ std::move (avioContextHolder),
493+ audioStreamOptions);
494+ encoder.encode ();
495+ }
496+
444497// For testing only. We need to implement this operation as a core library
445498// function because what we're testing is round-tripping pts values as
446499// double-precision floating point numbers from C++ to Python and back to C++.
@@ -694,6 +747,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
694747TORCH_LIBRARY_IMPL (torchcodec_ns, BackendSelect, m) {
695748 m.impl (" create_from_file" , &create_from_file);
696749 m.impl (" create_from_tensor" , &create_from_tensor);
750+ m.impl (" _create_from_file_like" , &_create_from_file_like);
697751 m.impl (" _convert_to_tensor" , &_convert_to_tensor);
698752 m.impl (
699753 " _get_json_ffmpeg_library_versions" , &_get_json_ffmpeg_library_versions);
@@ -702,6 +756,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
702756TORCH_LIBRARY_IMPL (torchcodec_ns, CPU, m) {
703757 m.impl (" encode_audio_to_file" , &encode_audio_to_file);
704758 m.impl (" encode_audio_to_tensor" , &encode_audio_to_tensor);
759+ m.impl (" _encode_audio_to_file_like" , &_encode_audio_to_file_like);
705760 m.impl (" seek_to_pts" , &seek_to_pts);
706761 m.impl (" add_video_stream" , &add_video_stream);
707762 m.impl (" _add_video_stream" , &_add_video_stream);
0 commit comments