Skip to content

Commit 2de2e98

Browse files
Introducing CVCUDA Backend (pytorch#9259)
Summary:

Summary
-------

This PR provides the first building blocks for CV-CUDA integration in torchvision. We add the functionals `to_cvcuda_tensor` and `cvcuda_to_tensor` to transform from `torch.Tensor` to `cvcuda.Tensor` and back. We also implement the corresponding class transforms `ToCVCUDATensor` and `CVCUDAToTensor`.

**Key features:**

* **RGB and grayscale support only**: Simplified API focusing on the most common use cases (3-channel RGB and 1-channel grayscale images)
* **Supported data types**: `torch.uint8` (RGB8 format) and `torch.float32` (RGBf32 format)
* **Lossless round-trip conversions**: Exact data preservation when converting PyTorch ↔ CV-CUDA
* **Informative error messages**: Helpful installation instructions when CV-CUDA is not available
* **Batch-aware**: Handles both unbatched (CHW) and batched (NCHW) tensors

Users must explicitly opt-in for these transforms, which require CV-CUDA to be installed.

How to use
----------

```python
from PIL import Image
import torchvision.transforms.v2.functional as F

# Load and convert image to PyTorch tensor
orig_img = Image.open("leaning_tower.jpg")
img_tensor = F.pil_to_tensor(orig_img)

# Convert to CV-CUDA tensor (must be 3-channel RGB on CUDA)
cvcuda_tensor = F.to_cvcuda_tensor(img_tensor.cuda())

# Convert back to PyTorch tensor
img_tensor = F.cvcuda_to_tensor(cvcuda_tensor)
```

> [!NOTE]
>
> * NVCV tensors are automatically converted to NHWC layout, contrary to torchvision's NCHW default
> * Only 3-channel RGB images and 1-channel grayscale are supported for now
> * Input tensors will be uploaded to CUDA device when converting to CV-CUDA tensors
> * CV-CUDA must be installed: `pip install cvcuda-cu12` (CUDA 12) or `pip install cvcuda-cu11` (CUDA 11)

Run unit tests
--------------

```bash
pytest test/test_transforms_v2.py -k "cvcuda"
...
35 passed, 4 skipped, 9774 deselected in 1.12s
```

Differential Revision: D85862362

Pulled By: AntoineSimoulin
1 parent acccf86 commit 2de2e98

File tree

6 files changed

+310
-5
lines changed

6 files changed

+310
-5
lines changed

test/common_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
23-
from torchvision.transforms.v2.functional import to_image, to_pil_image
23+
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
2424
from torchvision.utils import _Image_fromarray
2525

2626

@@ -400,6 +400,10 @@ def make_image_pil(*args, **kwargs):
400400
return to_pil_image(make_image(*args, **kwargs))
401401

402402

403+
def make_image_cvcuda(*args, **kwargs):
    """Build a test image with ``make_image`` and convert it with ``to_cvcuda_tensor``.

    All positional and keyword arguments are forwarded unchanged to ``make_image``.
    """
    image = make_image(*args, **kwargs)
    return to_cvcuda_tensor(image)
405+
406+
403407
def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
404408
y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device)
405409
x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device)

test/test_transforms_v2.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
make_bounding_boxes,
3030
make_detection_masks,
3131
make_image,
32+
make_image_cvcuda,
3233
make_image_pil,
3334
make_image_tensor,
3435
make_keypoints,
@@ -51,8 +52,16 @@
5152
from torchvision.transforms.v2 import functional as F
5253
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
5354
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
55+
from torchvision.transforms.v2.functional._type_conversion import _import_cvcuda_modules
5456
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
5557

58+
# Probe the optional CV-CUDA dependency once at import time so the test classes
# below can be skipped when it cannot be imported.
try:
    _import_cvcuda_modules()
except ImportError:
    CVCUDA_AVAILABLE = False
else:
    CVCUDA_AVAILABLE = True
# The CV-CUDA tests also move tensors to a CUDA device.
CUDA_AVAILABLE = torch.cuda.is_available()
64+
5665

5766
# turns all warnings into errors for this module
5867
pytestmark = [pytest.mark.filterwarnings("error")]
@@ -6733,6 +6742,125 @@ def test_functional_error(self):
67336742
F.pil_to_tensor(object())
67346743

67356744

6745+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
class TestToCVCUDATensor:
    """Tests for to_cvcuda_tensor function following patterns from TestToPil"""

    # dtypes exercised by every conversion / round-trip test below
    DTYPES = [torch.uint8, torch.uint16, torch.float32, torch.float64]

    @staticmethod
    def _make_cuda_tensor(shape, dtype):
        """Return a random tensor of ``shape`` and ``dtype`` moved to the CUDA device.

        Integer dtypes are filled with values in ``[0, 256)``, floating dtypes with
        values in ``[0, 1)``.
        """
        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
        if dtype in (torch.uint8, torch.uint16):
            data = torch.randint(0, 256, shape, dtype=dtype)
        else:
            data = torch.rand(shape, dtype=dtype)
        return data.cuda()

    @pytest.mark.parametrize("dtype", DTYPES)
    def test_1_channel_to_cvcuda_tensor(self, dtype):
        img_data = self._make_cuda_tensor((1, 4, 4), dtype)
        cvcuda_img = F.to_cvcuda_tensor(img_data)
        assert cvcuda_img is not None

    @pytest.mark.parametrize("dtype", DTYPES)
    def test_3_channel_to_cvcuda_tensor(self, dtype):
        img_data = self._make_cuda_tensor((3, 4, 4), dtype)
        cvcuda_img = F.to_cvcuda_tensor(img_data)
        assert cvcuda_img is not None

    def test_invalid_input_type(self):
        with pytest.raises(TypeError, match=r"pic should be `torch.Tensor`"):
            F.to_cvcuda_tensor("invalid_input")

    @pytest.mark.parametrize(
        "shape",
        [
            (4,),  # 1D: too few dimensions
            (4, 4),  # 2D: no longer supported
            (1, 1, 3, 4, 4),  # 5D: too many dimensions
        ],
        ids=["1d", "2d", "5d"],
    )
    def test_invalid_dimensions(self, shape):
        # Build the input outside of ``pytest.raises`` so that only the call under
        # test can satisfy the expected ValueError.
        img_data = self._make_cuda_tensor(shape, torch.uint8)
        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
            F.to_cvcuda_tensor(img_data)

    @pytest.mark.parametrize("num_channels", [1, 3])
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_round_trip(self, num_channels, dtype):
        # Setup: Create a tensor in CHW format (PyTorch standard)
        original_tensor = self._make_cuda_tensor((num_channels, 4, 4), dtype)

        # Execute: Convert to CV-CUDA and back to tensor
        # CHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
        cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
        result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)

        # Remove batch dimension that was added during conversion since original was unbatched
        result_tensor = result_tensor.squeeze(0)

        # Assert: The round-trip conversion preserves the original tensor exactly
        torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)

    @pytest.mark.parametrize("num_channels", [1, 3])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("batch_size", [1, 2, 4])
    def test_round_trip_batched(self, num_channels, dtype, batch_size):
        # Setup: Create a batched tensor in NCHW format
        original_tensor = self._make_cuda_tensor((batch_size, num_channels, 4, 4), dtype)

        # Execute: Convert to CV-CUDA and back to tensor
        # NCHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
        cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
        result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)

        # Assert: The round-trip conversion preserves the original batched tensor exactly
        torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
        # Also verify batch size is preserved
        assert result_tensor.shape[0] == batch_size
6839+
6840+
6841+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
6842+
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
6843+
class TestCVDUDAToTensor:
6844+
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
6845+
@pytest.mark.parametrize(
6846+
"fn",
6847+
[F.cvcuda_to_tensor, transform_cls_to_functional(transforms.CVCUDAToTensor)],
6848+
)
6849+
def test_functional_and_transform(self, color_space, fn):
6850+
input = make_image_cvcuda(color_space=color_space)
6851+
6852+
output = fn(input)
6853+
6854+
assert isinstance(output, torch.Tensor)
6855+
# Convert input to tensor to compare sizes
6856+
input_tensor = F.cvcuda_to_tensor(input)
6857+
assert F.get_size(output) == F.get_size(input_tensor)
6858+
6859+
def test_functional_error(self):
6860+
with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"):
6861+
F.cvcuda_to_tensor(object())
6862+
6863+
67366864
class TestLambda:
67376865
@pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
67386866
@pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
ToDtype,
5656
)
5757
from ._temporal import UniformTemporalSubsample
58-
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
58+
from ._type_conversion import CVCUDAToTensor, PILToTensor, ToCVCUDATensor, ToImage, ToPILImage, ToPureTensor
5959
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size
6060

6161
from ._deprecated import ToTensor # usort: skip

torchvision/transforms/v2/_type_conversion.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from torchvision import tv_tensors
88
from torchvision.transforms.v2 import functional as F, Transform
9-
109
from torchvision.transforms.v2._utils import is_pure_tensor
10+
from torchvision.utils import _log_api_usage_once
1111

1212

1313
class PILToTensor(Transform):
@@ -90,3 +90,71 @@ class ToPureTensor(Transform):
9090

9191
def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor:
9292
return inpt.as_subclass(torch.Tensor)
93+
94+
95+
class ToCVCUDATensor:
    """Convert a torch.Tensor to cvcuda.Tensor

    This transform does not support torchscript.

    The conversion is delegated to
    :func:`~torchvision.transforms.v2.functional.to_cvcuda_tensor`, which accepts a
    torch.*Tensor of shape C x H x W (unbatched) or N x C x H x W (batched). Unbatched
    inputs get a leading batch dimension of size 1, so the resulting cvcuda.Tensor
    always uses the NHWC layout. The data is moved to the CUDA device as part of the
    conversion.

    Only 1-channel and 3-channel images are supported.
    """

    def __init__(self):
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (torch.Tensor): Image to be converted to cvcuda.Tensor, in CHW or NCHW format.

        Returns:
            cvcuda.Tensor: Image converted to cvcuda.Tensor with NHWC layout.

        Raises:
            TypeError: If ``pic`` is not a ``torch.Tensor``.
            ValueError: If ``pic`` is not 3- or 4-dimensional.
        """
        return F.to_cvcuda_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
120+
121+
122+
class CVCUDAToTensor:
    """Convert a `cvcuda.Tensor` to a `torch.Tensor` of the same type - this does not scale values.

    This transform does not support torchscript.

    The conversion is delegated to
    :func:`~torchvision.transforms.v2.functional.cvcuda_to_tensor`, which only handles
    batched inputs:

    - 4-dimensional input is read as (N, H, W, C) and returned as (N, C, H, W)
    - 3-dimensional input is read as (N, H, W) and returned as (N, 1, H, W)

    Unbatched (H, W, C) inputs are not detected: a 3-dimensional tensor is always
    interpreted as (N, H, W). The returned tensor is placed on the CUDA device.

    Example:
        >>> import cvcuda
        >>> import torch
        >>> import torchvision.transforms.v2 as T
        >>> # Create a CV-CUDA tensor (320x240 RGB)
        >>> # Note: In CV-CUDA 0.16.0+, Image/Tensor creation uses cvcuda module
        >>> img_tensor = torch.randint(0, 256, (1, 240, 320, 3), dtype=torch.uint8, device="cuda")
        >>> cvcuda_tensor = cvcuda.as_tensor(img_tensor, cvcuda.TensorLayout.NHWC)
        >>> tensor = T.CVCUDAToTensor()(cvcuda_tensor)
        >>> print(tensor.shape)
        torch.Size([1, 3, 240, 320])
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (cvcuda.Tensor): CV-CUDA Tensor to be converted to tensor.

        Returns:
            Tensor: Converted image in NCHW format (batched).

        Raises:
            TypeError: If ``pic`` is not a ``cvcuda.Tensor``.
            ValueError: If ``pic`` is not 3- or 4-dimensional.
        """
        return F.cvcuda_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,6 @@
162162
to_dtype_video,
163163
)
164164
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
165-
from ._type_conversion import pil_to_tensor, to_image, to_pil_image
165+
from ._type_conversion import cvcuda_to_tensor, pil_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
166166

167167
from ._deprecated import get_image_size, to_tensor # usort: skip

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,36 @@
1-
from typing import Union
1+
from typing import TYPE_CHECKING, Union
22

33
import numpy as np
44
import PIL.Image
55
import torch
66
from torchvision import tv_tensors
77
from torchvision.transforms import functional as _F
8+
from torchvision.utils import _log_api_usage_once
9+
10+
if TYPE_CHECKING:
11+
import cvcuda # type: ignore[import-not-found]
12+
13+
14+
def _import_cvcuda_modules():
15+
"""Import CV-CUDA modules with informative error message if not installed.
16+
17+
Returns:
18+
cvcuda module.
19+
20+
Raises:
21+
RuntimeError: If CV-CUDA is not installed.
22+
"""
23+
try:
24+
import cvcuda # type: ignore[import-not-found]
25+
26+
return cvcuda
27+
except ImportError as e:
28+
raise ImportError(
29+
"CV-CUDA is required but not installed. "
30+
"Please install it following the instructions at "
31+
"https://github.com/CVCUDA/CV-CUDA or via pip: "
32+
"`pip install cvcuda-cu12` (for CUDA 12) or `pip install cvcuda-cu11` (for CUDA 11)."
33+
) from e
834

935

1036
@torch.jit.unused
@@ -25,3 +51,82 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso
2551

2652
to_pil_image = _F.to_pil_image
2753
pil_to_tensor = _F.pil_to_tensor
54+
55+
56+
@torch.jit.unused
def to_cvcuda_tensor(pic) -> "cvcuda.Tensor":
    """Convert a torch.Tensor to cvcuda.Tensor. This function does not support torchscript.

    See :class:`~torchvision.transforms.v2.ToCVCUDATensor` for more details.

    Args:
        pic (torch.Tensor): Image to be converted to cvcuda.Tensor, either
            unbatched (CHW) or batched (NCHW). Documented as supporting only
            1-channel and 3-channel images; the channel count is not checked here.

    Returns:
        cvcuda.Tensor: Image converted to cvcuda.Tensor with NHWC layout. Unbatched
        inputs gain a leading batch dimension of size 1.

    Raises:
        ImportError: If CV-CUDA is not installed.
        TypeError: If ``pic`` is not a ``torch.Tensor``.
        ValueError: If ``pic`` is not 3- or 4-dimensional.
    """
    cvcuda = _import_cvcuda_modules()

    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(to_cvcuda_tensor)

    # Guard clauses: reject anything that is not a 3D/4D torch.Tensor up front.
    if not isinstance(pic, torch.Tensor):
        raise TypeError(f"pic should be `torch.Tensor`. Got {type(pic)}.")
    if pic.ndim not in (3, 4):
        raise ValueError(f"pic should be 3 or 4 dimensional. Got {pic.ndim} dimensions.")

    # CHW inputs get a fake batch dimension so everything below is NCHW.
    batched = pic.unsqueeze(0) if pic.ndim == 3 else pic

    # NCHW -> NHWC, then move to the CUDA device and make the memory contiguous
    # before handing the tensor to CV-CUDA.
    nhwc = batched.permute(0, 2, 3, 1)
    return cvcuda.as_tensor(nhwc.cuda().contiguous(), cvcuda.TensorLayout.NHWC)
93+
94+
95+
@torch.jit.unused
def cvcuda_to_tensor(cvcuda_img: "cvcuda.Tensor") -> torch.Tensor:
    """Convert a cvcuda.Tensor to a PyTorch tensor. This function does not support torchscript.

    Args:
        cvcuda_img (cvcuda.Tensor): cvcuda.Tensor to be converted to PyTorch tensor.
            Expected to be in NHWC or NHW layout (batched images only).

    Returns:
        torch.Tensor: Converted image in NCHW format (batched), on the CUDA device.

    Raises:
        ImportError: If CV-CUDA is not installed.
        TypeError: If ``cvcuda_img`` is not a ``cvcuda.Tensor``.
        ValueError: If the converted tensor is not 3- or 4-dimensional.
    """
    cvcuda = _import_cvcuda_modules()

    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(cvcuda_to_tensor)

    if not isinstance(cvcuda_img, cvcuda.Tensor):
        raise TypeError(f"cvcuda_img should be `cvcuda.Tensor`. Got {type(cvcuda_img)}.")

    # Hand the CV-CUDA buffer to PyTorch (per the original author's note, this goes
    # through the CUDA array interface exposed by CV-CUDA tensors).
    cuda_tensor = torch.as_tensor(cvcuda_img.cuda(), device="cuda")

    rank = cuda_tensor.ndim
    if rank == 4:
        # Batched multi-channel image: NHWC -> NCHW
        return cuda_tensor.permute(0, 3, 1, 2)
    if rank == 3:
        # Batched single-channel image: NHW -> NCHW by inserting the channel dimension
        return cuda_tensor.unsqueeze(1)
    raise ValueError(f"Image should be 3 or 4 dimensional. Got {cuda_tensor.ndim} dimensions.")

0 commit comments

Comments
 (0)