@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
 #endif
 }
 
-torch::Tensor allocateDeviceTensor(
-    at::IntArrayRef shape,
-    torch::Device device,
-    const torch::Dtype dtype = torch::kUInt8) {
-  return torch::empty(
-      shape,
-      torch::TensorOptions()
-          .dtype(dtype)
-          .layout(torch::kStrided)
-          .device(device));
-}
-
 void throwErrorIfNonCudaDevice(const torch::Device& device) {
   TORCH_CHECK(
       device.type() != torch::kCPU,
@@ -202,7 +190,7 @@ void convertAVFrameToDecodedOutputOnCuda(
     AVCodecContext* codecContext,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   AVFrame* src = rawOutput.frame.get();
 
   TORCH_CHECK(
@@ -213,22 +201,6 @@ void convertAVFrameToDecodedOutputOnCuda(
   int height = options.height.value_or(codecContext->height);
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {src->data[0], src->data[1]};
-  torch::Tensor& dst = output.frame;
-  if (preAllocatedOutputTensor.has_value()) {
-    dst = preAllocatedOutputTensor.value();
-    auto shape = dst.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
-            (shape[2] == 3),
-        "Expected tensor of shape ",
-        height,
-        "x",
-        width,
-        "x3, got ",
-        shape);
-  } else {
-    dst = allocateDeviceTensor({height, width, 3}, options.device);
-  }
 
   // Use the user-requested GPU for running the NPP kernel.
   c10::cuda::CUDAGuard deviceGuard(device);
@@ -238,8 +210,8 @@ void convertAVFrameToDecodedOutputOnCuda(
   NppStatus status = nppiNV12ToRGB_8u_P2C3R(
       input,
       src->linesize[0],
-      static_cast<Npp8u*>(dst.data_ptr()),
-      dst.stride(0),
+      static_cast<Npp8u*>(preAllocatedOutputTensor.data_ptr()),
+      preAllocatedOutputTensor.stride(0),
       oSizeROI);
   TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
   // Make the pytorch stream wait for the npp kernel to finish before using the
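
A note on the change above (not part of the diff): with the std::optional parameter and the internal allocateDeviceTensor fallback removed, convertAVFrameToDecodedOutputOnCuda now relies on the caller to supply an already-allocated output tensor. Below is a minimal caller-side sketch; the helper name is hypothetical, and it assumes the same {height, width, 3} uint8 HWC layout on the target CUDA device that the removed shape check enforced.

#include <torch/torch.h>

// Hypothetical caller-side helper (illustrative only): allocate the HWC uint8
// tensor that is then passed in as preAllocatedOutputTensor.
torch::Tensor allocateOutputTensorForCuda(
    int height,
    int width,
    const torch::Device& device) {
  return torch::empty(
      {height, width, 3},
      torch::TensorOptions()
          .dtype(torch::kUInt8)
          .layout(torch::kStrided)
          .device(device));
}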