|
17 | 17 | import numpy as np |
18 | 18 | import nvidia.dali.experimental.dynamic as ndd |
19 | 19 | from nose2.tools import params |
20 | | -from nose_utils import assert_raises |
| 20 | +from nose_utils import assert_raises, attr |
21 | 21 | from test_tensor import asnumpy |
22 | 22 |
|
23 | 23 |
|
@@ -140,10 +140,47 @@ def test_binary_scalars(device: str, op: str, batch_size: int | None): |
140 | 140 | raise AssertionError(msg) |
141 | 141 |
|
142 | 142 |
|
| 143 | +@attr("pytorch") |
| 144 | +@params(*binary_ops) |
| 145 | +def test_binary_pytorch_gpu(op: str): |
| 146 | + import torch |
| 147 | + |
| 148 | + a = torch.tensor([1, 2, 3], device="cuda") |
| 149 | + b = ndd.as_tensor(a) |
| 150 | + |
| 151 | + result = apply_bin_op(op, a, b) |
| 152 | + result_rev = apply_bin_op(op, b, a) |
| 153 | + expected = apply_bin_op(op, a, a) |
| 154 | + np.testing.assert_array_equal(result.cpu(), expected.cpu()) |
| 155 | + np.testing.assert_array_equal(expected.cpu(), result.cpu()) |
| 156 | + |
| 157 | + |
143 | 158 | @params(*binary_ops) |
144 | 159 | def test_incompatible_devices(op: str): |
145 | 160 | a = ndd.tensor([1, 2, 3], device="cpu") |
146 | 161 | b = ndd.tensor([4, 5, 6], device="gpu") |
147 | 162 |
|
148 | 163 | with assert_raises(ValueError, regex="[CG]PU and [CG]PU"): |
149 | 164 | apply_bin_op(op, a, b) |
| 165 | + with assert_raises(ValueError, regex="[CG]PU and [CG]PU"): |
| 166 | + apply_bin_op(op, b, a) |
| 167 | + |
| 168 | + |
| 169 | +@attr("pytorch") |
| 170 | +@params(*binary_ops) |
| 171 | +def test_binary_pytorch_incompatible(op: str): |
| 172 | + import torch |
| 173 | + |
| 174 | + devices = [ |
| 175 | + ("cpu", "gpu"), |
| 176 | + ("cuda", "cpu"), |
| 177 | + ] |
| 178 | + |
| 179 | + for torch_device, ndd_device in devices: |
| 180 | + a = torch.tensor([1, 2, 3], device=torch_device) |
| 181 | + b = ndd.tensor([1, 2, 3], device=ndd_device) |
| 182 | + |
| 183 | + with assert_raises(ValueError, regex="[CG]PU and [CG]PU"): |
| 184 | + apply_bin_op(op, a, b) |
| 185 | + with assert_raises(ValueError, regex="[CG]PU and [CG]PU"): |
| 186 | + apply_bin_op(op, b, a) |
0 commit comments