6 changes: 3 additions & 3 deletions dali/python/backend_impl.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -1609,8 +1609,8 @@ void ExposeTesorListGPU(py::module &m) {
R"code(
List of tensors residing in the GPU memory.
)code")
.def_static("broadcast", [](const Tensor<CPUBackend> &t, int num_samples) {
return std::make_shared<TensorList<CPUBackend>>(t, num_samples);
.def_static("broadcast", [](const Tensor<GPUBackend> &t, int num_samples) {
return std::make_shared<TensorList<GPUBackend>>(t, num_samples);
})
.def("as_cpu", [](TensorList<GPUBackend> &t) {
DeviceGuard g(t.device_id());
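For context: this is the constructor the Python-layer tests below exercise when a batch is broadcast from a single sample on the GPU. A minimal sketch of the call path, assuming the experimental module is imported as ndd, as in the updated test file; that Batch.broadcast with device="gpu" reaches this exact binding is an inference from the diff, not something stated in it:

import numpy as np
import nvidia.dali.experimental.dynamic as ndd

# Broadcast one sample into a 4-sample batch; with device="gpu" the backend
# now builds a TensorList<GPUBackend> from a GPU tensor rather than the CPU one.
sample = np.array([[1, 2, 3], [4, 5, 6]])
batch = ndd.Batch.broadcast(sample, batch_size=4, device="gpu")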
46 changes: 46 additions & 0 deletions dali/python/nvidia/dali/experimental/dynamic/_arithmetic.py
@@ -0,0 +1,46 @@
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numbers
from typing import Any


def _implicitly_convertible(value: Any):
return isinstance(value, (numbers.Real, list, tuple))


def _arithm_op(name: str, *args):
from . import _arithmetic_generic_op
from ._batch import Batch
from ._tensor import Tensor, as_tensor

# scalar arguments are turned into tensors
argsstr = " ".join(f"&{i}" for i in range(len(args)))
gpu = any(arg.device.device_type == "gpu" for arg in args if isinstance(arg, (Tensor, Batch)))

new_args = []
for arg in args:
if not isinstance(arg, (Tensor, Batch)):
if gpu and _implicitly_convertible(arg):
arg = as_tensor(arg, device="gpu")
else:
arg = as_tensor(arg)

if (arg.device.device_type == "gpu") != gpu:
raise ValueError("Cannot mix GPU and CPU inputs.")

new_args.append(arg)

return _arithmetic_generic_op(*new_args, expression_desc=f"{name}({argsstr})")
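A hedged sketch of how the operator overloads reach this shared helper. The __rxor__ override in the _tensor.py diff below is one confirmed call site; that __add__ routes the same way, and that as_tensor without a device argument places data on the CPU, are assumptions made for illustration only:

import numpy as np
import nvidia.dali.experimental.dynamic as ndd

x = ndd.as_tensor(np.array([1, 2, 3]), device="gpu")

# A Python scalar is _implicitly_convertible (numbers.Real), so _arithm_op
# promotes it with as_tensor(3, device="gpu") and both operands stay on the GPU.
y = x + 3

# An explicit CPU tensor mixed with a GPU tensor is rejected by the device check.
cpu = ndd.as_tensor(np.array([1, 2, 3]))
try:
    x + cpu
except ValueError as e:
    print(e)  # "Cannot mix GPU and CPU inputs."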
33 changes: 2 additions & 31 deletions dali/python/nvidia/dali/experimental/dynamic/_batch.py
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
import nvtx

from . import _eval_mode, _invocation
from ._arithmetic import _arithm_op
from ._device import Device
from ._device import device as _device
from ._tensor import Tensor, _is_full_slice, _try_convert_enums
@@ -96,36 +97,6 @@ def __getitem__(self, ranges: Any) -> "Batch":
return _tensor_subscript(self._batch, **args)


def _arithm_op(name, *args, **kwargs):
from . import _arithmetic_generic_op

argsstr = " ".join(f"&{i}" for i in range(len(args)))
gpu = False
new_args = [None] * len(args)
for i, a in enumerate(args):
if isinstance(a, (Batch, Tensor)):
if a.device.device_type == "gpu":
gpu = True
else:
# TODO(michalz): We might use some caching here for common values.
if new_args is None:
new_args = list(args)
if gpu:
new_args[i] = _as_tensor(a, device="gpu")
else:
new_args[i] = _as_tensor(a)
if new_args[i].device.device_type == "gpu":
gpu = True

for i in range(len(args)):
if new_args[i] is None:
if (args[i].device.device_type == "gpu") != gpu:
raise ValueError("Cannot mix GPU and CPU inputs.")
new_args[i] = args[i]

return _arithmetic_generic_op(*new_args, expression_desc=f"{name}({argsstr})")


class _TensorList:
# `_TensorList` is what you get from `batch.tensors`.
# `_TensorList` is private because it's never meant to be constructed by the user and merely
12 changes: 3 additions & 9 deletions dali/python/nvidia/dali/experimental/dynamic/_tensor.py
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,12 @@
import copy
from typing import Any, Optional, SupportsInt, Tuple, Union

import numpy as np
import nvidia.dali.backend as _backend
import nvidia.dali.types
import numpy as np

from . import _eval_mode, _invocation
from ._arithmetic import _arithm_op
from ._device import Device
from ._device import device as _device
from ._eval_context import EvalContext as _EvalContext
@@ -632,13 +633,6 @@ def __rxor__(self, other):
return _arithm_op("bitxor", other, self)


def _arithm_op(name, *args, **kwargs):
argsstr = " ".join(f"&{i}" for i in range(len(args)))
from . import _arithmetic_generic_op

return _arithmetic_generic_op(*args, expression_desc=f"{name}({argsstr})")


def _is_int_value(tested: Any, reference: int) -> bool:
return isinstance(tested, int) and tested == reference

87 changes: 84 additions & 3 deletions dali/test/python/experimental_mode/test_arithm_ops.py
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import numpy as np
import nvidia.dali.experimental.dynamic as ndd
from nose2.tools import params
import numpy as np
import itertools
from nose_utils import assert_raises, attr
from test_tensor import asnumpy


@@ -103,3 +105,82 @@ def test_unary_ops(device, op):
if not np.array_equal(asnumpy(y), ref_y):
msg = f"{ref_x} {op} = \n{asnumpy(y)}\n!=\n{ref_y}"
raise AssertionError(msg)


@params(*itertools.product(["gpu", "cpu"], binary_ops, (None, 4)))
def test_binary_scalars(device: str, op: str, batch_size: int | None):
tensors = [
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[1], [2], [3]]),
np.array([[1, 2, 3], [4, 5, 6]]),
]
scalars = [3, [4, 5, 6]]

for tensor, scalar in itertools.product(tensors, scalars):
if op == "/":
tensor = tensor.astype(np.float32)

if batch_size is None:
x = ndd.as_tensor(tensor, device=device)
else:
x = ndd.Batch.broadcast(tensor, batch_size=batch_size, device=device)

result = ndd.as_tensor(apply_bin_op(op, x, scalar))
result_rev = ndd.as_tensor(apply_bin_op(op, scalar, x))
ref = apply_bin_op(op, tensor, scalar)
ref_rev = apply_bin_op(op, scalar, tensor)

# np.allclose supports broadcasting
if not np.allclose(result.cpu(), ref):
msg = f"{tensor} {op} {scalar} = \n{result}\n!=\n{ref}"
raise AssertionError(msg)

if not np.allclose(result_rev.cpu(), ref_rev):
msg = f"{scalar} {op} {tensor} = \n{result_rev}\n!=\n{ref_rev}"
raise AssertionError(msg)


@attr("pytorch")
@params(*binary_ops)
def test_binary_pytorch_gpu(op: str):
import torch

a = torch.tensor([1, 2, 3], device="cuda")
b = ndd.as_tensor(a)

result = apply_bin_op(op, a, b)
result_rev = apply_bin_op(op, b, a)
expected = apply_bin_op(op, a, a)
np.testing.assert_array_equal(result.cpu(), expected.cpu())
np.testing.assert_array_equal(expected.cpu(), result_rev.cpu())


@params(*binary_ops)
def test_incompatible_devices(op: str):
a = ndd.tensor([1, 2, 3], device="cpu")
b = ndd.tensor([4, 5, 6], device="gpu")

with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
apply_bin_op(op, a, b)
with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
apply_bin_op(op, b, a)


@attr("pytorch")
@params(*binary_ops)
def test_binary_pytorch_incompatible(op: str):
import torch

devices = [
("cpu", "gpu"),
("cuda", "cpu"),
]

for torch_device, ndd_device in devices:
a = torch.tensor([1, 2, 3], device=torch_device)
b = ndd.tensor([1, 2, 3], device=ndd_device)

with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
apply_bin_op(op, a, b)
with assert_raises(ValueError, regex="[CG]PU and [CG]PU"):
apply_bin_op(op, b, a)
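The new tests call an apply_bin_op helper that is defined earlier in test_arithm_ops.py and therefore sits outside the visible hunks. A hypothetical stand-in is sketched here only so the diff reads self-contained; the real helper, and the binary_ops list it pairs with, may differ:

import operator

# Hypothetical mapping from the operator strings in binary_ops to Python
# operators; the actual helper in test_arithm_ops.py is not part of this diff.
_BIN_OPS = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": operator.truediv,
    "//": operator.floordiv,
    "**": operator.pow,
    "&": operator.and_,
    "|": operator.or_,
    "^": operator.xor,
}


def apply_bin_op(op, lhs, rhs):
    # Dispatch on the operator name and rely on the operands' own
    # __op__/__rop__ overloads (DALI tensors, NumPy arrays, Python scalars).
    return _BIN_OPS[op](lhs, rhs)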