|
#include <iostream>
#include <sstream>
#include <string>

extern "C" {
#include <libavutil/pixdesc.h>
}
2 | 6 |
|
3 | 7 | #include "src/torchcodec/_core/AVIOTensorContext.h" |
@@ -507,4 +511,328 @@ void AudioEncoder::flushBuffers() { |
507 | 511 |
|
508 | 512 | encodeFrame(autoAVPacket, UniqueAVFrame(nullptr)); |
509 | 513 | } |
| 514 | + |
| 515 | +namespace { |
| 516 | + |
| 517 | +torch::Tensor validateFrames(const torch::Tensor& frames) { |
| 518 | + TORCH_CHECK( |
| 519 | + frames.dtype() == torch::kFloat32 || frames.dtype() == torch::kUInt8, |
| 520 | + "frames must have float32 or kUInt8 dtype, got ", |
| 521 | + frames.dtype()); |
| 522 | + TORCH_CHECK( |
| 523 | + frames.dim() == 4, |
| 524 | + "frames must have 4 dimensions (N, H, W, C) or (N, C, H, W), got ", |
| 525 | + frames.dim()); |
| 526 | + |
| 527 | + return frames.contiguous(); |
| 528 | +} |
| 529 | + |
// Describes the memory layout of the input frames tensor and the FFmpeg
// pixel format it maps to. Filled in by analyzeTensorFormat().
struct TensorFormat {
  bool isNCHW; // true: (N, C, H, W) channels-first; false: (N, H, W, C)
  int numChannels; // 3 (RGB) or 4 (RGBA)
  int width;
  int height;
  AVPixelFormat pixelFormat; // planar (GBRP/GBRAP) for NCHW, packed for NHWC
};
| 537 | + |
| 538 | +TensorFormat analyzeTensorFormat(const torch::Tensor& frames) { |
| 539 | + auto sizes = frames.sizes(); |
| 540 | + TORCH_CHECK( |
| 541 | + sizes.size() == 4, "Expected 4D tensor (N, C, H, W) or (N, H, W, C)"); |
| 542 | + |
| 543 | + bool isNCHW = sizes[1] == 3 || sizes[1] == 4; |
| 544 | + |
| 545 | + int numChannels = isNCHW ? sizes[1] : sizes[3]; |
| 546 | + int height = isNCHW ? sizes[2] : sizes[1]; |
| 547 | + int width = isNCHW ? sizes[3] : sizes[2]; |
| 548 | + |
| 549 | + AVPixelFormat pixelFormat; |
| 550 | + if (isNCHW) { |
| 551 | + pixelFormat = |
| 552 | + (numChannels == 3) ? AV_PIX_FMT_GBRP : AV_PIX_FMT_GBRAP; // Planar |
| 553 | + } else { |
| 554 | + pixelFormat = |
| 555 | + (numChannels == 3) ? AV_PIX_FMT_RGB24 : AV_PIX_FMT_RGBA; // Packed |
| 556 | + } |
| 557 | + return {isNCHW, numChannels, width, height, pixelFormat}; |
| 558 | +} |
| 559 | + |
| 560 | +} // namespace |
| 561 | + |
// Ensures the output is flushed and a file-backed AVIO handle is released,
// even if encode() was never called. Safe after encode(): close_avio() nulls
// the pb pointer after closing, so the second call is a no-op.
VideoEncoder::~VideoEncoder() {
  close_avio();
}
| 565 | + |
| 566 | +void VideoEncoder::close_avio() { |
| 567 | + if (avFormatContext_ && avFormatContext_->pb) { |
| 568 | + avio_flush(avFormatContext_->pb); |
| 569 | + |
| 570 | + if (!avioContextHolder_) { |
| 571 | + avio_close(avFormatContext_->pb); |
| 572 | + // avoids closing again in destructor, which would segfault. |
| 573 | + avFormatContext_->pb = nullptr; |
| 574 | + } |
| 575 | + } |
| 576 | +} |
| 577 | + |
| 578 | +VideoEncoder::VideoEncoder( |
| 579 | + const torch::Tensor& frames, |
| 580 | + int frameRate, |
| 581 | + std::string_view fileName, |
| 582 | + const VideoStreamOptions& videoStreamOptions) |
| 583 | + : frames_(validateFrames(frames)), frameRate_(frameRate) { |
| 584 | + setFFmpegLogLevel(); |
| 585 | + |
| 586 | + // Allocate output format context |
| 587 | + AVFormatContext* avFormatContext = nullptr; |
| 588 | + int status = avformat_alloc_output_context2( |
| 589 | + &avFormatContext, nullptr, nullptr, fileName.data()); |
| 590 | + |
| 591 | + TORCH_CHECK( |
| 592 | + avFormatContext != nullptr, |
| 593 | + "Couldn't allocate AVFormatContext. ", |
| 594 | + "The destination file is ", |
| 595 | + fileName, |
| 596 | + ", check the desired extension? ", |
| 597 | + getFFMPEGErrorStringFromErrorCode(status)); |
| 598 | + avFormatContext_.reset(avFormatContext); |
| 599 | + |
| 600 | + status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); |
| 601 | + TORCH_CHECK( |
| 602 | + status >= 0, |
| 603 | + "avio_open failed. The destination file is ", |
| 604 | + fileName, |
| 605 | + ", make sure it's a valid path? ", |
| 606 | + getFFMPEGErrorStringFromErrorCode(status)); |
| 607 | + |
| 608 | + initializeEncoder(videoStreamOptions); |
| 609 | +} |
| 610 | + |
| 611 | +void VideoEncoder::initializeEncoder( |
| 612 | + const VideoStreamOptions& videoStreamOptions) { |
| 613 | + // TODO-VideoEncoder: Allow FFmpeg to pick codec based on container format? |
| 614 | + // Currently, this causes errors for some containers (avi) |
| 615 | + // const AVCodec* avCodec = |
| 616 | + // avcodec_find_encoder(avFormatContext_->oformat->video_codec); |
| 617 | + const AVCodec* avCodec = avcodec_find_encoder(AV_CODEC_ID_H264); |
| 618 | + TORCH_CHECK(avCodec != nullptr, "Video codec not found"); |
| 619 | + |
| 620 | + AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); |
| 621 | + TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); |
| 622 | + avCodecContext_.reset(avCodecContext); |
| 623 | + |
| 624 | + // Set encoding options |
| 625 | + // TODO-VideoEncoder: Allow bitrate to be set |
| 626 | + std::optional<int> desiredBitRate = videoStreamOptions.bitRate; |
| 627 | + if (desiredBitRate.has_value()) { |
| 628 | + TORCH_CHECK( |
| 629 | + *desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0."); |
| 630 | + } |
| 631 | + avCodecContext_->bit_rate = desiredBitRate.value_or(0); |
| 632 | + // TODO-VideoEncoder: Verify that frame_rate and time_base are correct |
| 633 | + avCodecContext_->time_base = {1, frameRate_}; |
| 634 | + avCodecContext_->framerate = {frameRate_, 1}; |
| 635 | + |
| 636 | + // Analyze tensor format once and store results in member variables |
| 637 | + TensorFormat format = analyzeTensorFormat(frames_); |
| 638 | + isNCHW_ = format.isNCHW; |
| 639 | + inWidth_ = format.width; |
| 640 | + inHeight_ = format.height; |
| 641 | + inPixelFormat_ = format.pixelFormat; |
| 642 | + |
| 643 | + // Use specified dimensions or input dimensions |
| 644 | + // TODO-VideoEncoder: Allow height and width to be set |
| 645 | + outWidth_ = videoStreamOptions.width.value_or(inWidth_); |
| 646 | + outHeight_ = videoStreamOptions.height.value_or(inHeight_); |
| 647 | + |
| 648 | + // Use YUV420P as default output format |
| 649 | + outPixelFormat_ = AV_PIX_FMT_YUV420P; |
| 650 | + |
| 651 | + // Configure codec parameters |
| 652 | + avCodecContext_->codec_id = avCodec->id; |
| 653 | + avCodecContext_->width = outWidth_; |
| 654 | + avCodecContext_->height = outHeight_; |
| 655 | + avCodecContext_->pix_fmt = outPixelFormat_; |
| 656 | + avCodecContext_->time_base = {1, frameRate_}; |
| 657 | + |
| 658 | + // TODO-VideoEncoder: Allow GOP size and max B-frames to be set |
| 659 | + if (videoStreamOptions.gopSize.has_value()) { |
| 660 | + avCodecContext_->gop_size = *videoStreamOptions.gopSize; |
| 661 | + } else { |
| 662 | + avCodecContext_->gop_size = 12; // Default GOP size |
| 663 | + } |
| 664 | + |
| 665 | + if (videoStreamOptions.maxBFrames.has_value()) { |
| 666 | + avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames; |
| 667 | + } else { |
| 668 | + avCodecContext_->max_b_frames = 2; // Default max B-frames |
| 669 | + } |
| 670 | + |
| 671 | + int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); |
| 672 | + TORCH_CHECK( |
| 673 | + status == AVSUCCESS, |
| 674 | + "avcodec_open2 failed: ", |
| 675 | + getFFMPEGErrorStringFromErrorCode(status)); |
| 676 | + |
| 677 | + AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr); |
| 678 | + TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); |
| 679 | + status = avcodec_parameters_from_context( |
| 680 | + avStream->codecpar, avCodecContext_.get()); |
| 681 | + TORCH_CHECK( |
| 682 | + status == AVSUCCESS, |
| 683 | + "avcodec_parameters_from_context failed: ", |
| 684 | + getFFMPEGErrorStringFromErrorCode(status)); |
| 685 | + streamIndex_ = avStream->index; |
| 686 | +} |
| 687 | + |
| 688 | +UniqueAVFrame VideoEncoder::convertTensorToAVFrame( |
| 689 | + const torch::Tensor& frameTensor, |
| 690 | + int frameIndex) { |
| 691 | + // Initialize and cache scaling context if it does not exist |
| 692 | + if (!swsContext_) { |
| 693 | + swsContext_.reset(sws_getContext( |
| 694 | + inWidth_, |
| 695 | + inHeight_, |
| 696 | + inPixelFormat_, |
| 697 | + outWidth_, |
| 698 | + outHeight_, |
| 699 | + outPixelFormat_, |
| 700 | + SWS_BILINEAR, |
| 701 | + nullptr, |
| 702 | + nullptr, |
| 703 | + nullptr)); |
| 704 | + TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); |
| 705 | + } |
| 706 | + |
| 707 | + UniqueAVFrame avFrame(av_frame_alloc()); |
| 708 | + TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); |
| 709 | + |
| 710 | + // Set output frame properties |
| 711 | + avFrame->format = outPixelFormat_; |
| 712 | + avFrame->width = outWidth_; |
| 713 | + avFrame->height = outHeight_; |
| 714 | + avFrame->pts = frameIndex; |
| 715 | + |
| 716 | + int status = av_frame_get_buffer(avFrame.get(), 32); |
| 717 | + TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); |
| 718 | + |
| 719 | + // Need to convert/scale the frame |
| 720 | + // Create temporary frame with input format |
| 721 | + UniqueAVFrame inputFrame(av_frame_alloc()); |
| 722 | + TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); |
| 723 | + |
| 724 | + inputFrame->format = inPixelFormat_; |
| 725 | + inputFrame->width = inWidth_; |
| 726 | + inputFrame->height = inHeight_; |
| 727 | + |
| 728 | + uint8_t* tensorData = static_cast<uint8_t*>(frameTensor.data_ptr()); |
| 729 | + |
| 730 | + if (isNCHW_) { |
| 731 | + int channelSize = inHeight_ * inWidth_; |
| 732 | + // Reorder RGB -> GBR for AV_PIX_FMT_GBRP or AV_PIX_FMT_GBRAP formats |
| 733 | + inputFrame->data[0] = tensorData + channelSize; |
| 734 | + inputFrame->data[1] = tensorData + (2 * channelSize); |
| 735 | + inputFrame->data[2] = tensorData; |
| 736 | + |
| 737 | + inputFrame->linesize[0] = inWidth_; // width of B channel |
| 738 | + inputFrame->linesize[1] = inWidth_; // width of G channel |
| 739 | + inputFrame->linesize[2] = inWidth_; // width of R channel |
| 740 | + } else { |
| 741 | + // NHWC is usually in packed format |
| 742 | + inputFrame->data[0] = tensorData; |
| 743 | + auto sizes = frameTensor.sizes(); |
| 744 | + // width * channels |
| 745 | + inputFrame->linesize[0] = inWidth_ * sizes[sizes.size() - 1]; |
| 746 | + } |
| 747 | + // Perform scaling/conversion |
| 748 | + status = sws_scale( |
| 749 | + swsContext_.get(), |
| 750 | + inputFrame->data, |
| 751 | + inputFrame->linesize, |
| 752 | + 0, |
| 753 | + inputFrame->height, |
| 754 | + avFrame->data, |
| 755 | + avFrame->linesize); |
| 756 | + TORCH_CHECK(status == outHeight_, "sws_scale failed"); |
| 757 | + return avFrame; |
| 758 | +} |
| 759 | + |
| 760 | +void VideoEncoder::encode() { |
| 761 | + // To be on the safe side we enforce that encode() can only be called once |
| 762 | + TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); |
| 763 | + encodeWasCalled_ = true; |
| 764 | + |
| 765 | + int status = avformat_write_header(avFormatContext_.get(), nullptr); |
| 766 | + TORCH_CHECK( |
| 767 | + status == AVSUCCESS, |
| 768 | + "Error in avformat_write_header: ", |
| 769 | + getFFMPEGErrorStringFromErrorCode(status)); |
| 770 | + |
| 771 | + AutoAVPacket autoAVPacket; |
| 772 | + int numFrames = frames_.sizes()[0]; |
| 773 | + for (int i = 0; i < numFrames; ++i) { |
| 774 | + torch::Tensor singleFrame = frames_.select(0, i); |
| 775 | + UniqueAVFrame avFrame = convertTensorToAVFrame(singleFrame, i); |
| 776 | + encodeFrame(autoAVPacket, avFrame); |
| 777 | + } |
| 778 | + |
| 779 | + flushBuffers(); |
| 780 | + |
| 781 | + status = av_write_trailer(avFormatContext_.get()); |
| 782 | + TORCH_CHECK( |
| 783 | + status == AVSUCCESS, |
| 784 | + "Error in av_write_trailer: ", |
| 785 | + getFFMPEGErrorStringFromErrorCode(status)); |
| 786 | + |
| 787 | + close_avio(); |
| 788 | +} |
| 789 | + |
// Sends one frame (or nullptr to signal end-of-stream) to the encoder and
// drains every packet the encoder has ready, writing each to the muxer with
// timestamps rescaled from codec to stream time base.
void VideoEncoder::encodeFrame(
    AutoAVPacket& autoAVPacket,
    const UniqueAVFrame& avFrame) {
  auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Error while sending frame: ",
      getFFMPEGErrorStringFromErrorCode(status));

  while (true) {
    ReferenceAVPacket packet(autoAVPacket);
    status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
    // EAGAIN: encoder needs more input before producing a packet.
    // EOF: encoder fully drained (only after a nullptr frame was sent).
    if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
      if (status == AVERROR_EOF) {
        // Flush remaining buffered packets
        // (a nullptr packet tells the muxer to write out its interleaving
        // buffer).
        status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
        TORCH_CHECK(
            status == AVSUCCESS,
            "Failed to flush packet: ",
            getFFMPEGErrorStringFromErrorCode(status));
      }
      return;
    }
    TORCH_CHECK(
        status >= 0,
        "Error receiving packet: ",
        getFFMPEGErrorStringFromErrorCode(status));

    // Packet timestamps come out in codec time_base; the muxer expects them
    // in the stream's time_base.
    av_packet_rescale_ts(
        packet.get(),
        avCodecContext_->time_base,
        avFormatContext_->streams[streamIndex_]->time_base);
    packet->stream_index = streamIndex_;

    status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
    TORCH_CHECK(
        status == AVSUCCESS,
        "Error in av_interleaved_write_frame: ",
        getFFMPEGErrorStringFromErrorCode(status));
  }
}
| 831 | + |
| 832 | +void VideoEncoder::flushBuffers() { |
| 833 | + AutoAVPacket autoAVPacket; |
| 834 | + // Send NULL frame to signal end of input |
| 835 | + encodeFrame(autoAVPacket, UniqueAVFrame(nullptr)); |
| 836 | +} |
| 837 | + |
510 | 838 | } // namespace facebook::torchcodec |
0 commit comments