Skip to content

Commit 685583c

Browse files
author
Molly Xu
committed
address comments
1 parent 7dd12c1 commit 685583c

File tree

8 files changed

+11
-19
lines changed

8 files changed

+11
-19
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,6 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
814814
void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
815815
UniqueAVFrame& avFrame,
816816
FrameOutput& frameOutput,
817-
[[maybe_unused]] AVMediaType mediaType,
818817
std::optional<torch::Tensor> preAllocatedOutputTensor) {
819818
UniqueAVFrame gpuFrame =
820819
cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame);

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
4646
void convertAVFrameToFrameOutput(
4747
UniqueAVFrame& avFrame,
4848
FrameOutput& frameOutput,
49-
AVMediaType mediaType,
50-
std::optional<torch::Tensor> preAllocatedOutputTensor =
51-
std::nullopt) override;
49+
std::optional<torch::Tensor> preAllocatedOutputTensor) override;
5250

5351
int sendPacket(ReferenceAVPacket& packet) override;
5452
int sendEOFPacket() override;

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ void CpuDeviceInterface::initializeVideo(
3535
const VideoStreamOptions& videoStreamOptions,
3636
const std::vector<std::unique_ptr<Transform>>& transforms,
3737
const std::optional<FrameDims>& resizedOutputDims) {
38+
avMediaType_ = AVMEDIA_TYPE_VIDEO;
3839
videoStreamOptions_ = videoStreamOptions;
3940
resizedOutputDims_ = resizedOutputDims;
4041

@@ -88,6 +89,7 @@ void CpuDeviceInterface::initializeVideo(
8889

8990
void CpuDeviceInterface::initializeAudio(
9091
const AudioStreamOptions& audioStreamOptions) {
92+
avMediaType_ = AVMEDIA_TYPE_AUDIO;
9193
audioStreamOptions_ = audioStreamOptions;
9294
initialized_ = true;
9395
}
@@ -123,11 +125,10 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
123125
void CpuDeviceInterface::convertAVFrameToFrameOutput(
124126
UniqueAVFrame& avFrame,
125127
FrameOutput& frameOutput,
126-
AVMediaType mediaType,
127128
std::optional<torch::Tensor> preAllocatedOutputTensor) {
128129
TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
129130

130-
if (mediaType == AVMEDIA_TYPE_AUDIO) {
131+
if (avMediaType_ == AVMEDIA_TYPE_AUDIO) {
131132
convertAudioAVFrameToFrameOutput(avFrame, frameOutput);
132133
} else {
133134
convertVideoAVFrameToFrameOutput(
@@ -390,7 +391,8 @@ std::optional<torch::Tensor> CpuDeviceInterface::maybeFlushAudioBuffers() {
390391
if (!swrContext_) {
391392
return std::nullopt;
392393
}
393-
auto numRemainingSamples = swr_get_out_samples(swrContext_.get(), 0);
394+
auto numRemainingSamples = // this is an upper bound
395+
swr_get_out_samples(swrContext_.get(), 0);
394396

395397
if (numRemainingSamples == 0) {
396398
return std::nullopt;

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ class CpuDeviceInterface : public DeviceInterface {
4141
void convertAVFrameToFrameOutput(
4242
UniqueAVFrame& avFrame,
4343
FrameOutput& frameOutput,
44-
AVMediaType mediaType,
45-
std::optional<torch::Tensor> preAllocatedOutputTensor =
46-
std::nullopt) override;
44+
std::optional<torch::Tensor> preAllocatedOutputTensor) override;
4745

4846
std::string getDetails() override;
4947

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
238238
void CudaDeviceInterface::convertAVFrameToFrameOutput(
239239
UniqueAVFrame& avFrame,
240240
FrameOutput& frameOutput,
241-
[[maybe_unused]] AVMediaType mediaType,
242241
std::optional<torch::Tensor> preAllocatedOutputTensor) {
243242
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);
244243

@@ -272,8 +271,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
272271
} else {
273272
// Reason 2 above. We need to do a full conversion which requires an
274273
// actual CPU device.
275-
cpuInterface_->convertAVFrameToFrameOutput(
276-
avFrame, cpuFrameOutput, AVMEDIA_TYPE_VIDEO);
274+
cpuInterface_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
277275
}
278276

279277
// Finally, we need to send the frame back to the GPU. Note that the

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ class CudaDeviceInterface : public DeviceInterface {
3737
void convertAVFrameToFrameOutput(
3838
UniqueAVFrame& avFrame,
3939
FrameOutput& frameOutput,
40-
AVMediaType mediaType,
41-
std::optional<torch::Tensor> preAllocatedOutputTensor =
42-
std::nullopt) override;
40+
std::optional<torch::Tensor> preAllocatedOutputTensor) override;
4341

4442
std::string getDetails() override;
4543

src/torchcodec/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ class DeviceInterface {
9090
virtual void convertAVFrameToFrameOutput(
9191
UniqueAVFrame& avFrame,
9292
FrameOutput& frameOutput,
93-
AVMediaType mediaType,
9493
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
9594

9695
// ------------------------------------------
@@ -142,6 +141,7 @@ class DeviceInterface {
142141
protected:
143142
torch::Device device_;
144143
SharedAVCodecContext codecContext_;
144+
AVMediaType avMediaType_;
145145
};
146146

147147
using CreateDeviceInterfaceFn =

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,15 +1289,14 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
12891289
std::optional<torch::Tensor> preAllocatedOutputTensor) {
12901290
// Convert the frame to tensor.
12911291
FrameOutput frameOutput;
1292-
auto& streamInfo = streamInfos_[activeStreamIndex_];
12931292
frameOutput.ptsSeconds = ptsToSeconds(
12941293
getPtsOrDts(avFrame),
12951294
formatContext_->streams[activeStreamIndex_]->time_base);
12961295
frameOutput.durationSeconds = ptsToSeconds(
12971296
getDuration(avFrame),
12981297
formatContext_->streams[activeStreamIndex_]->time_base);
12991298
deviceInterface_->convertAVFrameToFrameOutput(
1300-
avFrame, frameOutput, streamInfo.avMediaType, preAllocatedOutputTensor);
1299+
avFrame, frameOutput, preAllocatedOutputTensor);
13011300
return frameOutput;
13021301
}
13031302

0 commit comments

Comments
 (0)