Skip to content

Commit fcbe4a3

Browse files
committed
Handle GPU?
1 parent 5c6e23c commit fcbe4a3

File tree

4 files changed

+11
-35
lines changed

4 files changed

+11
-35
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ void convertAVFrameToDecodedOutputOnCuda(
2020
AVCodecContext* codecContext,
2121
VideoDecoder::RawDecodedOutput& rawOutput,
2222
VideoDecoder::DecodedOutput& output,
23-
std::optional<torch::Tensor> preAllocatedOutputTensor) {
23+
torch::Tensor preAllocatedOutputTensor) {
2424
throwUnsupportedDeviceError(device);
2525
}
2626

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 3 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
154154
#endif
155155
}
156156

157-
torch::Tensor allocateDeviceTensor(
158-
at::IntArrayRef shape,
159-
torch::Device device,
160-
const torch::Dtype dtype = torch::kUInt8) {
161-
return torch::empty(
162-
shape,
163-
torch::TensorOptions()
164-
.dtype(dtype)
165-
.layout(torch::kStrided)
166-
.device(device));
167-
}
168-
169157
void throwErrorIfNonCudaDevice(const torch::Device& device) {
170158
TORCH_CHECK(
171159
device.type() != torch::kCPU,
@@ -202,7 +190,7 @@ void convertAVFrameToDecodedOutputOnCuda(
202190
AVCodecContext* codecContext,
203191
VideoDecoder::RawDecodedOutput& rawOutput,
204192
VideoDecoder::DecodedOutput& output,
205-
std::optional<torch::Tensor> preAllocatedOutputTensor) {
193+
torch::Tensor preAllocatedOutputTensor) {
206194
AVFrame* src = rawOutput.frame.get();
207195

208196
TORCH_CHECK(
@@ -213,22 +201,6 @@ void convertAVFrameToDecodedOutputOnCuda(
213201
int height = options.height.value_or(codecContext->height);
214202
NppiSize oSizeROI = {width, height};
215203
Npp8u* input[2] = {src->data[0], src->data[1]};
216-
torch::Tensor& dst = output.frame;
217-
if (preAllocatedOutputTensor.has_value()) {
218-
dst = preAllocatedOutputTensor.value();
219-
auto shape = dst.sizes();
220-
TORCH_CHECK(
221-
(shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
222-
(shape[2] == 3),
223-
"Expected tensor of shape ",
224-
height,
225-
"x",
226-
width,
227-
"x3, got ",
228-
shape);
229-
} else {
230-
dst = allocateDeviceTensor({height, width, 3}, options.device);
231-
}
232204

233205
// Use the user-requested GPU for running the NPP kernel.
234206
c10::cuda::CUDAGuard deviceGuard(device);
@@ -238,8 +210,8 @@ void convertAVFrameToDecodedOutputOnCuda(
238210
NppStatus status = nppiNV12ToRGB_8u_P2C3R(
239211
input,
240212
src->linesize[0],
241-
static_cast<Npp8u*>(dst.data_ptr()),
242-
dst.stride(0),
213+
static_cast<Npp8u*>(preAllocatedOutputTensor.data_ptr()),
214+
preAllocatedOutputTensor.stride(0),
243215
oSizeROI);
244216
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
245217
// Make the pytorch stream wait for the npp kernel to finish before using the

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ void convertAVFrameToDecodedOutputOnCuda(
3838
AVCodecContext* codecContext,
3939
VideoDecoder::RawDecodedOutput& rawOutput,
4040
VideoDecoder::DecodedOutput& output,
41-
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
41+
torch::Tensor preAllocatedOutputTensor);
4242

4343
void releaseContextOnCuda(
4444
const torch::Device& device,

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -195,10 +195,14 @@ torch::Tensor VideoDecoder::allocateEmptyHWCTensorForStream(
195195
auto height = options.height.value_or(*metadata.height);
196196
auto width = options.width.value_or(*metadata.width);
197197

198+
auto tensorOptions = torch::TensorOptions()
199+
.dtype(torch::kUInt8)
200+
.layout(torch::kStrided)
201+
.device(options.device.type());
198202
if (numFrames.has_value()) {
199-
return torch::empty({numFrames.value(), height, width, 3}, {torch::kUInt8});
203+
return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
200204
} else {
201-
return torch::empty({height, width, 3}, {torch::kUInt8});
205+
return torch::empty({height, width, 3}, tensorOptions);
202206
}
203207
}
204208

0 commit comments

Comments
 (0)