Skip to content

Commit e15a79e

Browse files
committed
Refactor NvVideoDecoder to replace deprecated nvcv_image with cvcuda tensor
- Updated NvVideoDecoder to remove the use of nvcv_image, which is deprecated, and replaced it with cvcuda tensor. - Adjusted related tensor operations and tests to ensure compatibility with the new cvcuda implementation. Signed-off-by: Abhinav Garg <abhinavg@stanford.edu>
1 parent 37a53a6 commit e15a79e

File tree

2 files changed

+47
-46
lines changed

2 files changed

+47
-46
lines changed

nemo_curator/utils/nvcodec_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,18 +268,18 @@ def generate_decoded_frames(self) -> list[torch.Tensor]:
268268
for packet in self.nvDemux:
269269
list_frames = self.nvDec.Decode(packet)
270270
for decoded_frame in list_frames:
271-
# TODO: Replace nvcv with cvcuda. Before that remove the use of nvcv_image. It's deprecated
272-
nvcv_tensor = cvcuda.as_tensor(cvcuda.as_image(decoded_frame.nvcv_image(), cvcuda.Format.U8))
273-
if nvcv_tensor.layout == "NCHW":
274-
nchw_shape = nvcv_tensor.shape
271+
# TODO: Remove the use of nvcv_image. It's deprecated
272+
cvcuda_tensor = cvcuda.as_tensor(cvcuda.as_image(decoded_frame.nvcv_image(), cvcuda.Format.U8))
273+
if cvcuda_tensor.layout == "NCHW":
274+
nchw_shape = cvcuda_tensor.shape
275275
nhwc_shape = (nchw_shape[0], nchw_shape[2], nchw_shape[3], nchw_shape[1])
276276
torch_nhwc = torch.empty(
277277
nhwc_shape,
278278
dtype=torch.uint8,
279279
device=f"cuda:{self.device_id}",
280280
)
281281
cvcuda_nhwc = cvcuda.as_tensor(torch_nhwc.cuda(self.device_id), "NHWC")
282-
cvcuda.reformat_into(cvcuda_nhwc, nvcv_tensor, stream=self.cvcuda_stream)
282+
cvcuda.reformat_into(cvcuda_nhwc, cvcuda_tensor, stream=self.cvcuda_stream)
283283
# Push the decoded frame with the reformatted frame to keep it alive.
284284
self.input_frame_list.put(torch_nhwc)
285285
else:

tests/utils/test_nvcodec_utils.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def test_video_batch_decoder_init_without_dependencies(self) -> None:
6767

6868
@patch("nemo_curator.utils.nvcodec_utils.cuda", None)
6969
@patch("nemo_curator.utils.nvcodec_utils.cvcuda", None)
70-
@patch("nemo_curator.utils.nvcodec_utils.nvcv", None)
7170
@patch("nemo_curator.utils.nvcodec_utils.Nvc", None)
7271
def test_py_nvc_frame_extractor_without_dependencies(self) -> None:
7372
"""Test PyNvcFrameExtractor fails gracefully when dependencies are missing."""
@@ -92,7 +91,7 @@ def test_module_imports_gracefully_without_dependencies(self) -> None:
9291
"""Test that the module can be imported even when GPU dependencies are missing."""
9392
# If we got here, the import was successful
9493
# This test verifies that import failures are handled gracefully
95-
from nemo_curator.utils import nvcodec_utils
94+
from nemo_curator.utils import nvcodec_utils # noqa: PLC0415
9695

9796
# Verify the module has expected attributes
9897
assert hasattr(nvcodec_utils, "FrameExtractionPolicy")
@@ -528,10 +527,7 @@ def test_init(self, _mock_logger: Any, mock_nvc: Any) -> None:
528527
@patch("nemo_curator.utils.nvcodec_utils.Nvc")
529528
@patch("nemo_curator.utils.nvcodec_utils.torch")
530529
@patch("nemo_curator.utils.nvcodec_utils.cvcuda")
531-
@patch("nemo_curator.utils.nvcodec_utils.nvcv")
532-
def test_generate_decoded_frames(
533-
self, _mock_nvcv: Any, _mock_cvcuda: Any, _mock_torch: Any, mock_nvc: Any
534-
) -> None:
530+
def test_generate_decoded_frames(self, _mock_cvcuda: Any, _mock_torch: Any, mock_nvc: Any) -> None:
535531
"""Test generate_decoded_frames method."""
536532
# Setup mocks
537533
mock_demux = Mock()
@@ -655,10 +651,7 @@ def test_get_next_frames_multiple_frames(self, _mock_torch: Any, mock_nvc: Any)
655651
@patch("nemo_curator.utils.nvcodec_utils.Nvc")
656652
@patch("nemo_curator.utils.nvcodec_utils.torch")
657653
@patch("nemo_curator.utils.nvcodec_utils.cvcuda")
658-
@patch("nemo_curator.utils.nvcodec_utils.nvcv")
659-
def test_generate_decoded_frames_with_frames(
660-
self, mock_nvcv: Any, mock_cvcuda: Any, mock_torch: Any, mock_nvc: Any
661-
) -> None:
654+
def test_generate_decoded_frames_with_frames(self, mock_cvcuda: Any, mock_torch: Any, mock_nvc: Any) -> None:
662655
"""Test generate_decoded_frames with actual frame processing."""
663656
# Setup mocks
664657
mock_demux = Mock()
@@ -677,20 +670,27 @@ def test_generate_decoded_frames_with_frames(
677670
mock_decoded_frame.nvcv_image.return_value = Mock()
678671

679672
# Mock tensor operations
680-
mock_nvcv_tensor = Mock()
681-
mock_nvcv_tensor.layout = "NCHW"
682-
mock_nvcv_tensor.shape = (1, 3, 480, 640) # NCHW format
683-
mock_nvcv.as_tensor.return_value = mock_nvcv_tensor
684-
mock_nvcv.as_image.return_value = Mock()
685-
mock_nvcv.Format.U8 = Mock()
673+
mock_cvcuda_tensor = Mock()
674+
mock_cvcuda_tensor.layout = "NCHW"
675+
mock_cvcuda_tensor.shape = (1, 3, 480, 640) # NCHW format
676+
mock_cvcuda.as_tensor.return_value = mock_cvcuda_tensor
677+
mock_cvcuda.as_image.return_value = Mock()
678+
mock_cvcuda.Format.U8 = Mock()
686679

687680
# Mock torch tensor
688681
mock_torch_nhwc = Mock()
689682
mock_torch.empty.return_value = mock_torch_nhwc
690683

691-
# Mock cvcuda tensor
684+
# Mock cvcuda tensor for NHWC conversion
692685
mock_cvcuda_nhwc = Mock()
693-
mock_cvcuda.as_tensor.return_value = mock_cvcuda_nhwc
686+
687+
# Setup side effect for as_tensor to return different mocks
688+
def as_tensor_side_effect(*args: Any, **_kwargs: Any) -> Any:
689+
if len(args) > 0 and hasattr(args[0], "cuda"):
690+
return mock_cvcuda_nhwc
691+
return mock_cvcuda_tensor
692+
693+
mock_cvcuda.as_tensor.side_effect = as_tensor_side_effect
694694

695695
# Mock demux iteration - return one packet with one frame
696696
mock_demux.__iter__ = Mock(return_value=iter([mock_packet]))
@@ -707,9 +707,8 @@ def test_generate_decoded_frames_with_frames(
707707
result = decoder.generate_decoded_frames()
708708

709709
# Verify frame processing was called
710-
mock_nvcv.as_tensor.assert_called_once()
710+
mock_cvcuda.as_tensor.assert_called()
711711
mock_torch.empty.assert_called_once()
712-
mock_cvcuda.as_tensor.assert_called_once()
713712
mock_cvcuda.reformat_into.assert_called_once()
714713

715714
# Should return the processed frames
@@ -718,9 +717,8 @@ def test_generate_decoded_frames_with_frames(
718717
@patch("nemo_curator.utils.nvcodec_utils.Nvc")
719718
@patch("nemo_curator.utils.nvcodec_utils.torch")
720719
@patch("nemo_curator.utils.nvcodec_utils.cvcuda")
721-
@patch("nemo_curator.utils.nvcodec_utils.nvcv")
722720
def test_generate_decoded_frames_unexpected_layout(
723-
self, mock_nvcv: Any, _mock_cvcuda: Any, _mock_torch: Any, mock_nvc: Any
721+
self, mock_cvcuda: Any, _mock_torch: Any, mock_nvc: Any
724722
) -> None:
725723
"""Test generate_decoded_frames with unexpected tensor layout."""
726724
# Setup mocks
@@ -740,11 +738,11 @@ def test_generate_decoded_frames_unexpected_layout(
740738
mock_decoded_frame.nvcv_image.return_value = Mock()
741739

742740
# Mock tensor with unexpected layout
743-
mock_nvcv_tensor = Mock()
744-
mock_nvcv_tensor.layout = "NHWC" # Unexpected layout - should be NCHW
745-
mock_nvcv.as_tensor.return_value = mock_nvcv_tensor
746-
mock_nvcv.as_image.return_value = Mock()
747-
mock_nvcv.Format.U8 = Mock()
741+
mock_cvcuda_tensor = Mock()
742+
mock_cvcuda_tensor.layout = "NHWC" # Unexpected layout - should be NCHW
743+
mock_cvcuda.as_tensor.return_value = mock_cvcuda_tensor
744+
mock_cvcuda.as_image.return_value = Mock()
745+
mock_cvcuda.Format.U8 = Mock()
748746

749747
# Mock demux iteration
750748
mock_demux.__iter__ = Mock(return_value=iter([mock_packet]))
@@ -764,10 +762,7 @@ def test_generate_decoded_frames_unexpected_layout(
764762
@patch("nemo_curator.utils.nvcodec_utils.Nvc")
765763
@patch("nemo_curator.utils.nvcodec_utils.torch")
766764
@patch("nemo_curator.utils.nvcodec_utils.cvcuda")
767-
@patch("nemo_curator.utils.nvcodec_utils.nvcv")
768-
def test_generate_decoded_frames_partial_batch(
769-
self, mock_nvcv: Any, mock_cvcuda: Any, mock_torch: Any, mock_nvc: Any
770-
) -> None:
765+
def test_generate_decoded_frames_partial_batch(self, mock_cvcuda: Any, mock_torch: Any, mock_nvc: Any) -> None:
771766
"""Test generate_decoded_frames with partial batch (less frames than batch_size)."""
772767
# Setup mocks
773768
mock_demux = Mock()
@@ -789,17 +784,24 @@ def test_generate_decoded_frames_partial_batch(
789784
mock_frames.append(frame)
790785

791786
# Mock tensor operations
792-
mock_nvcv_tensor = Mock()
793-
mock_nvcv_tensor.layout = "NCHW"
794-
mock_nvcv_tensor.shape = (1, 3, 480, 640)
795-
mock_nvcv.as_tensor.return_value = mock_nvcv_tensor
796-
mock_nvcv.as_image.return_value = Mock()
797-
mock_nvcv.Format.U8 = Mock()
787+
mock_cvcuda_tensor = Mock()
788+
mock_cvcuda_tensor.layout = "NCHW"
789+
mock_cvcuda_tensor.shape = (1, 3, 480, 640)
790+
mock_cvcuda.as_tensor.return_value = mock_cvcuda_tensor
791+
mock_cvcuda.as_image.return_value = Mock()
792+
mock_cvcuda.Format.U8 = Mock()
798793

799794
mock_torch_nhwc = Mock()
800795
mock_torch.empty.return_value = mock_torch_nhwc
801796
mock_cvcuda_nhwc = Mock()
802-
mock_cvcuda.as_tensor.return_value = mock_cvcuda_nhwc
797+
798+
# Setup side effect for as_tensor to return different mocks
799+
def as_tensor_side_effect(*args: Any, **_kwargs: Any) -> Any:
800+
if len(args) > 0 and hasattr(args[0], "cuda"):
801+
return mock_cvcuda_nhwc
802+
return mock_cvcuda_tensor
803+
804+
mock_cvcuda.as_tensor.side_effect = as_tensor_side_effect
803805

804806
# Mock demux to return frames across multiple packets then end
805807
decode_calls = []
@@ -825,9 +827,8 @@ def test_generate_decoded_frames_partial_batch(
825827
assert len(result) == 2
826828

827829
# Verify frame processing was called
828-
mock_nvcv.as_tensor.assert_called()
829-
mock_torch.empty.assert_called()
830830
mock_cvcuda.as_tensor.assert_called()
831+
mock_torch.empty.assert_called()
831832
mock_cvcuda.reformat_into.assert_called()
832833

833834

@@ -1178,7 +1179,7 @@ def test_error_messages_are_helpful(self) -> None:
11781179
def test_all_classes_can_be_imported(self) -> None:
11791180
"""Test that all public classes can be imported regardless of dependency availability."""
11801181
# All these should be importable even when dependencies are missing
1181-
from nemo_curator.utils.nvcodec_utils import (
1182+
from nemo_curator.utils.nvcodec_utils import ( # noqa: PLC0415
11821183
FrameExtractionPolicy,
11831184
NvVideoDecoder,
11841185
PyNvcFrameExtractor,

0 commit comments

Comments
 (0)