Commit 03f9013

Properly handle data that's already on the GPU
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent 9319fa9 commit 03f9013

2 files changed: +43 −8 lines

dali/python/nvidia/dali/experimental/dynamic/_arithmetic.py

Lines changed: 5 additions & 7 deletions
@@ -33,15 +33,13 @@ def _arithm_op(name: str, *args):
     new_args = []
     for arg in args:
         if not isinstance(arg, (Tensor, Batch)):
-            if gpu and not _implicitly_convertible(arg):
-                raise ValueError(f"Type {type(arg)} is not implicitly copyable to the GPU.")
+            if gpu and _implicitly_convertible(arg):
+                arg = as_tensor(arg, device="gpu")
+            arg = as_tensor(arg)
 
-            device = "gpu" if gpu else None
-            arg = as_tensor(arg, device=device)
+        if (arg.device.device_type == "gpu") != gpu:
+            raise ValueError("Cannot mix GPU and CPU inputs.")
 
         new_args.append(arg)
 
-    if any((arg.device.device_type == "gpu") != gpu for arg in new_args):
-        raise ValueError("Cannot mix GPU and CPU inputs.")
-
     return _arithmetic_generic_op(*new_args, expression_desc=f"{name}({argsstr})")
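
The net effect of this hunk: implicitly convertible Python values are promoted to GPU tensors when the expression has GPU operands, and the device check now runs per argument instead of in a second pass over new_args. The following is a minimal usage sketch of the intended behavior, not part of the commit; it assumes ndd tensors overload the Python arithmetic operators (the tests below apply them through apply_bin_op), that plain Python scalars count as implicitly convertible, and that .cpu() returns a host-side copy.

# Hedged sketch, not part of the diff: illustrates the behavior the reworked
# _arithm_op is aiming for.  Operator overloading on ndd tensors and the
# convertibility of Python scalars are assumptions, not confirmed by this commit.
import nvidia.dali.experimental.dynamic as ndd

gpu_t = ndd.tensor([1, 2, 3], device="gpu")

# A Python scalar is assumed implicitly convertible, so it is promoted to a
# GPU tensor and the expression stays on the GPU.
result = gpu_t + 2

# A CPU tensor operand now fails fast with the per-argument device check.
cpu_t = ndd.tensor([4, 5, 6], device="cpu")
try:
    gpu_t + cpu_t
except ValueError as err:
    print(err)  # "Cannot mix GPU and CPU inputs."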

dali/test/python/experimental_mode/test_arithm_ops.py

Lines changed: 38 additions & 1 deletion
@@ -17,7 +17,7 @@
 import numpy as np
 import nvidia.dali.experimental.dynamic as ndd
 from nose2.tools import params
-from nose_utils import assert_raises
+from nose_utils import assert_raises, attr
 from test_tensor import asnumpy
 
 
@@ -140,10 +140,47 @@ def test_binary_scalars(device: str, op: str, batch_size: int | None):
         raise AssertionError(msg)
 
 
+@attr("pytorch")
+@params(*binary_ops)
+def test_binary_pytorch_gpu(op: str):
+    import torch
+
+    a = torch.tensor([1, 2, 3], device="cuda")
+    b = ndd.as_tensor(a)
+
+    result = apply_bin_op(op, a, b)
+    result_rev = apply_bin_op(op, b, a)
+    expected = apply_bin_op(op, a, a)
+    np.testing.assert_array_equal(result.cpu(), expected.cpu())
+    np.testing.assert_array_equal(result_rev.cpu(), expected.cpu())
+
+
 @params(*binary_ops)
 def test_incompatible_devices(op: str):
     a = ndd.tensor([1, 2, 3], device="cpu")
     b = ndd.tensor([4, 5, 6], device="gpu")
 
     with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
         apply_bin_op(op, a, b)
+    with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
+        apply_bin_op(op, b, a)
+
+
+@attr("pytorch")
+@params(*binary_ops)
+def test_binary_pytorch_incompatible(op: str):
+    import torch
+
+    devices = [
+        ("cpu", "gpu"),
+        ("cuda", "cpu"),
+    ]
+
+    for torch_device, ndd_device in devices:
+        a = torch.tensor([1, 2, 3], device=torch_device)
+        b = ndd.tensor([1, 2, 3], device=ndd_device)
+
+        with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
+            apply_bin_op(op, a, b)
+        with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
+            apply_bin_op(op, b, a)
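
The new tests exercise the case named in the commit title: data that already lives on the GPU, here a torch CUDA tensor, is wrapped with ndd.as_tensor and combined with other GPU operands, while any CPU operand is rejected. Below is a hedged, stand-alone sketch of that flow, not part of the commit, under the same assumptions as above (operator overloading on ndd tensors, .cpu() returning a host-side, array-like copy).

# Hedged sketch, not part of the diff: the user-level flow the new tests cover.
import numpy as np
import torch
import nvidia.dali.experimental.dynamic as ndd

src = torch.tensor([1, 2, 3], device="cuda")  # data that is already on the GPU
wrapped = ndd.as_tensor(src)                  # wrapped as a GPU tensor, no host round trip assumed

total = wrapped + wrapped                     # assumed operator overloading; stays on the GPU
np.testing.assert_array_equal(total.cpu(), [2, 4, 6])

cpu_side = ndd.tensor([4, 5, 6], device="cpu")
try:
    wrapped + cpu_side                        # mixed devices are rejected up front
except ValueError as err:
    assert "GPU and CPU" in str(err)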
