Skip to content

Commit 781f956

Browse files
committed
Fix cuda
1 parent c06aa94 commit 781f956

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ void CpuDeviceInterface::initialize(
9494
if (areTransformsSwScaleCompatible &&
9595
(userRequestedSwScale || isWidthSwScaleCompatible)) {
9696
colorConversionLibrary_ = ColorConversionLibrary::SWSCALE;
97+
98+
// TODO(scott): set swsFlags_ here when swscale is selected.
99+
97100
} else {
98101
colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH;
99102

@@ -112,6 +115,8 @@ void CpuDeviceInterface::initialize(
112115
filters_ = filters.str();
113116
}
114117
}
118+
119+
initialized_ = true;
115120
}
116121

117122
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
@@ -127,6 +132,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
127132
UniqueAVFrame& avFrame,
128133
FrameOutput& frameOutput,
129134
std::optional<torch::Tensor> preAllocatedOutputTensor) {
135+
TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
130136
if (preAllocatedOutputTensor.has_value()) {
131137
auto shape = preAllocatedOutputTensor.value().sizes();
132138
TORCH_CHECK(
@@ -167,7 +173,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
167173
prevSwsFrameContext_ = swsFrameContext;
168174
}
169175
int resultHeight =
170-
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
176+
convertAVFrameToTensorUsingSwScale(avFrame, outputTensor);
171177
// If this check failed, it would mean that the frame wasn't reshaped to
172178
// the expected height.
173179
// TODO: Can we do the same check for width?
@@ -227,7 +233,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
227233
}
228234
}
229235

230-
int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
236+
int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
231237
const UniqueAVFrame& avFrame,
232238
torch::Tensor& outputTensor) {
233239
uint8_t* pointers[4] = {

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class CpuDeviceInterface : public DeviceInterface {
3737
std::nullopt) override;
3838

3939
private:
40-
int convertAVFrameToTensorUsingSwsScale(
40+
int convertAVFrameToTensorUsingSwScale(
4141
const UniqueAVFrame& avFrame,
4242
torch::Tensor& outputTensor);
4343

@@ -71,6 +71,10 @@ class CpuDeviceInterface : public DeviceInterface {
7171
AVRational timeBase_;
7272
FrameDims outputDims_;
7373

74+
// If we use swscale for resizing, the flags control the resizing algorithm.
75+
// We exclusively get the value from the ResizeTransform.
76+
int swsFlags_ = 0;
77+
7478
// The copy filter just copies the input to the output. Computationally, it
7579
// should be a no-op. If we get no user-provided transforms, we will use the
7680
// copy filter.
@@ -85,6 +89,8 @@ class CpuDeviceInterface : public DeviceInterface {
8589
// be created before decoding a new frame.
8690
SwsFrameContext prevSwsFrameContext_;
8791
FiltersContext prevFiltersContext_;
92+
93+
bool initialized_ = false;
8894
};
8995

9096
} // namespace facebook::torchcodec

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,17 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
216216
// Typically that happens if the video's encoder isn't supported by NVDEC.
217217
// Below, we choose to convert the frame's color-space using the CPU
218218
// codepath, and send it back to the GPU at the very end.
219+
//
219220
// TODO: A possibly better solution would be to send the frame to the GPU
220221
// first, and do the color conversion there.
222+
//
223+
// TODO: If we keep this CPU fallback, we should probably cache the CPU device interface instead of creating it on every call.
221224
auto cpuDevice = torch::Device(torch::kCPU);
222225
auto cpuInterface = createDeviceInterface(cpuDevice);
226+
TORCH_CHECK(
227+
cpuInterface != nullptr, "Failed to create CPU device interface");
228+
cpuDeviceInterface->initialize(
229+
nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_);
223230

224231
FrameOutput cpuFrameOutput;
225232
cpuInterface->convertAVFrameToFrameOutput(

0 commit comments

Comments
 (0)