@@ -94,9 +94,9 @@ SingleStreamDecoder::SingleStreamDecoder(
9494
// Destructor. Walks every entry of streamInfos_ and, for each stream that has
// a non-null device handle, asks that handle to release the stream's codec
// context (passed as the raw pointer from codecContext.get()). Streams whose
// handle is null are skipped.
//
// Diff note: the '-' lines are the previous version, which read the handle
// from streamInfo.videoStreamOptions.device; the '+' lines are the new
// version, which reads it from streamInfo.deviceInterface instead. The loop
// and the null check are otherwise unchanged.
9595SingleStreamDecoder::~SingleStreamDecoder () {
9696 for (auto & [streamIndex, streamInfo] : streamInfos_) {
97- auto & device = streamInfo.videoStreamOptions . device ;
98- if (device ) {
99- device ->releaseContext (streamInfo.codecContext .get ());
97+ auto & deviceInterface = streamInfo.deviceInterface ;
98+ if (deviceInterface ) {
99+ deviceInterface ->releaseContext (streamInfo.codecContext .get ());
100100 }
101101 }
102102}
@@ -386,7 +386,7 @@ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
386386void SingleStreamDecoder::addStream (
387387 int streamIndex,
388388 AVMediaType mediaType,
389- DeviceInterface* device,
389+ const torch::Device& device,
390390 std::optional<int > ffmpegThreadCount) {
391391 TORCH_CHECK (
392392 activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -414,6 +414,7 @@ void SingleStreamDecoder::addStream(
414414 streamInfo.timeBase = formatContext_->streams [activeStreamIndex_]->time_base ;
415415 streamInfo.stream = formatContext_->streams [activeStreamIndex_];
416416 streamInfo.avMediaType = mediaType;
417+ streamInfo.deviceInterface = createDeviceInterface (device);
417418
418419 // This should never happen, checking just to be safe.
419420 TORCH_CHECK (
@@ -425,9 +426,10 @@ void SingleStreamDecoder::addStream(
425426 // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
426427 // addStream() which is supposed to be generic
427428 if (mediaType == AVMEDIA_TYPE_VIDEO) {
428- if (device ) {
429+ if (streamInfo. deviceInterface ) {
429430 avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream (
430- device->findCodec (streamInfo.stream ->codecpar ->codec_id )
431+ streamInfo.deviceInterface
432+ ->findCodec (streamInfo.stream ->codecpar ->codec_id )
431433 .value_or (avCodec));
432434 }
433435 }
@@ -445,8 +447,8 @@ void SingleStreamDecoder::addStream(
445447
446448 // TODO_CODE_QUALITY same as above.
447449 if (mediaType == AVMEDIA_TYPE_VIDEO) {
448- if (device ) {
449- device ->initializeContext (codecContext);
450+ if (streamInfo. deviceInterface ) {
451+ streamInfo. deviceInterface ->initializeContext (codecContext);
450452 }
451453 }
452454
@@ -476,7 +478,7 @@ void SingleStreamDecoder::addVideoStream(
476478 addStream (
477479 streamIndex,
478480 AVMEDIA_TYPE_VIDEO,
479- videoStreamOptions.device . get () ,
481+ videoStreamOptions.device ,
480482 videoStreamOptions.ffmpegThreadCount );
481483
482484 auto & streamMetadata =
@@ -1217,11 +1219,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput(
12171219 formatContext_->streams [activeStreamIndex_]->time_base );
12181220 if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
12191221 convertAudioAVFrameToFrameOutputOnCPU (avFrame, frameOutput);
1220- } else if (!streamInfo.videoStreamOptions . device ) {
1222+ } else if (!streamInfo.deviceInterface ) {
12211223 convertAVFrameToFrameOutputOnCPU (
12221224 avFrame, frameOutput, preAllocatedOutputTensor);
1223- } else if (streamInfo.videoStreamOptions . device ) {
1224- streamInfo.videoStreamOptions . device ->convertAVFrameToFrameOutput (
1225+ } else if (streamInfo.deviceInterface ) {
1226+ streamInfo.deviceInterface ->convertAVFrameToFrameOutput (
12251227 streamInfo.videoStreamOptions ,
12261228 avFrame,
12271229 frameOutput,
@@ -1564,10 +1566,8 @@ SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput(
15641566 videoStreamOptions, streamMetadata);
15651567 int height = frameDims.height ;
15661568 int width = frameDims.width ;
1567- torch::Device device = (videoStreamOptions.device )
1568- ? videoStreamOptions.device ->device ()
1569- : torch::kCPU ;
1570- data = allocateEmptyHWCTensor (height, width, device, numFrames);
1569+ data = allocateEmptyHWCTensor (
1570+ height, width, videoStreamOptions.device , numFrames);
15711571}
15721572
15731573torch::Tensor allocateEmptyHWCTensor (
0 commit comments