
Torchvision API infrastructure#6229

Merged
mdabek-nvidia merged 11 commits into NVIDIA:main from mdabek-nvidia:torchvision_api_infra
Mar 18, 2026
Conversation

@mdabek-nvidia
Collaborator

Torchvision API infrastructure

Category:

New feature

Description:

This is the first in a series of PRs implementing the Torchvision API.
The full PR has been trimmed to include only the infrastructure and the resize and flip operators required for unit testing.

Additional information:

Affected modules and functionalities:

Key points relevant for the review:

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

Initial commit including infrastructure, resize and flip operators and unit tests

Signed-off-by: Marek Dabek <mdabek@nvidia.com>
@mdabek-nvidia
Collaborator Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [44795470]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [44795470]: BUILD PASSED

@mdabek-nvidia mdabek-nvidia marked this pull request as ready for review February 26, 2026 11:50
@greptile-apps
Contributor

greptile-apps bot commented Feb 26, 2026

Greptile Summary

This PR introduces the foundational infrastructure for a DALI-backed Torchvision-compatible transforms API (nvidia.dali.experimental.torchvision), covering the Compose pipeline orchestrator, Resize, RandomHorizontalFlip, and RandomVerticalFlip operators, and a functional sub-API — along with comprehensive Python tests. The implementation routes PIL Image inputs through an HWC pipeline and torch.Tensor inputs through a CHW pipeline, bridging DALI's internal data model to Torchvision's user-facing conventions.

Key points from the review:

  • Logic bug in compose.py (line 199): _internal_run is set to _cuda_run whenever torch.cuda.is_available() is true, even for CPU pipelines. This means CPU-device pipelines pass a CUDA stream to pipe.run() on machines with a GPU, which is semantically incorrect. The condition should be self.device == "gpu".
  • Typo in operator.py (line 992): Warning message says "copyig" instead of "copying".
  • Most previously-raised issues (import path fragility, **kwargs unpacking, CHW dimension ordering, unsupported interpolation mode guard, functools.wraps, CUDA stream device selection) have been addressed in this revision.
  • The dual-mode sizing logic (calculate_target_size_pipeline_mode / calculate_target_size_dynamic_mode) is well-structured; the autograph registration in _conditionals.py ensures DALI's conditional handling applies to the new module.

Confidence Score: 3/5

  • Safe to merge after fixing the _cuda_run / _cpu_run selection logic for CPU pipelines on CUDA machines.
  • The PR is well-tested and most previously-flagged issues have been resolved. One confirmed logic error remains in compose.py line 199: CPU pipelines incorrectly select _cuda_run when CUDA hardware is present, which passes a CUDA stream to a CPU DALI pipeline. While this may be benign in practice (DALI may ignore the stream for CPU pipelines), it is semantically wrong and could produce subtle misbehaviour in edge cases. A one-line fix resolves it.
  • dali/python/nvidia/dali/experimental/torchvision/v2/compose.py (line 199, _internal_run device selection)

Important Files Changed

Filename Overview
dali/python/nvidia/dali/experimental/torchvision/v2/compose.py Core pipeline infrastructure: routes PIL/Tensor inputs through HWC or CHW DALI pipelines. Contains a logic error where _cuda_run is selected for CPU pipelines whenever CUDA hardware is present (should be conditioned on self.device == "gpu"). Also retains dead if output is None guard from a prior review cycle.
dali/python/nvidia/dali/experimental/torchvision/v2/operator.py Base Operator class, verification rules, and the adjust_input decorator. Contains a minor typo in a warning message ("copyig" → "copying"). Logic is otherwise solid; type(self).verify_data correctly dispatches to subclass rules.
dali/python/nvidia/dali/experimental/torchvision/v2/resize.py Resize operator with dual pipeline/dynamic mode sizing logic. Layout detection in get_inputHW correctly handles NHWC/NCHW/HWC/CHW by stripping the leading 'N' before the single-character branch. VerificationSize allows max_size=0 (uses < 0 rather than <= 0) but that was addressed in a prior review thread.
dali/python/nvidia/dali/experimental/torchvision/v2/functional/resize.py Functional resize API. Uses proper relative imports. Layout handling correctly extracts H/W for HWC, NHWC, CHW, and NCHW, with an explicit ValueError for unrecognized layouts. not_supported_interpolation_modes guard is present.
dali/test/python/torchvision/test_tv_compose.py Comprehensive compose tests covering PIL/Tensor paths, multi-op pipelines, CUDA stream ordering, dtype preservation, and mode consistency. Good coverage of edge cases.
dali/test/python/torchvision/test_tv_resize.py Resize tests covering sizes, max_size, interpolation modes, antialiasing, large sizes, and dtype handling. The shape[1:3] comparison in _internal_loop is only valid for the 4D NHWC path through PIL, and for raw CHW/NCHW tensors it compares the wrong dimensions (previously flagged), but is otherwise thorough.
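The layout-detection rule described for resize.py (strip a leading ``N``, then branch on the single-character layout) can be sketched in plain Python. The helper name and signature are hypothetical, not the PR's actual code:

```python
def get_input_hw(shape, layout):
    """Sketch: return (H, W) for HWC/CHW layouts, with or without a
    leading batch dimension. Mirrors the rule described in the review:
    drop a leading 'N' from both layout and shape, then branch on the
    remaining three-character layout."""
    if layout.startswith("N"):
        layout = layout[1:]
        shape = shape[1:]
    if layout == "HWC":
        return shape[0], shape[1]
    if layout == "CHW":
        return shape[1], shape[2]
    raise ValueError(f"Unrecognized layout: {layout!r}")
```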

Sequence Diagram

sequenceDiagram
    participant User
    participant Compose
    participant LayoutPipe as PipelineHWC / PipelineCHW
    participant PWL as PipelineWithLayout
    participant DALI as DALI Pipeline (_pipeline_function)
    participant Op as Operator (_kernel)

    User->>Compose: __call__(data_input)
    Compose->>Compose: VerificationTensorOrImage.verify(data_input)
    alt first call
        Compose->>Compose: _build_pipeline(data_input)
        Note over Compose: PIL Image → PipelineHWC<br/>torch.Tensor → PipelineCHW
    end
    Compose->>LayoutPipe: run(data_input)
    LayoutPipe->>LayoutPipe: convert / unsqueeze input
    LayoutPipe->>PWL: run(tensor)
    PWL->>PWL: _align_data_with_device(tensor)
    alt device == "gpu" (and CUDA available)
        PWL->>DALI: _cuda_run(tensor) via current CUDA stream
    else
        PWL->>DALI: _cpu_run(tensor)
    end
    loop for each op in op_list
        DALI->>Op: op(DataNode)
        Op-->>DALI: transformed DataNode
    end
    DALI-->>PWL: TensorList output
    PWL->>PWL: to_torch_tensor(output)
    PWL-->>LayoutPipe: torch.Tensor
    LayoutPipe-->>LayoutPipe: convert back (PIL Image / squeeze)
    LayoutPipe-->>User: PIL Image or torch.Tensor

Comments Outside Diff (1)

  1. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 199 (link)

    CPU pipelines incorrectly use _cuda_run when CUDA is available

    torch.cuda.is_available() checks for hardware presence, but whether to use CUDA stream handling should depend on whether the pipeline is a GPU pipeline. When self.device == "cpu" on a CUDA machine, _cuda_run is selected, which creates a CUDA stream and passes it to self.pipe.run(stream, ...) for a CPU pipeline. This is semantically incorrect and may cause unexpected behavior (DALI receiving a CUDA stream for a CPU-mode pipeline).

    The condition should be based on the pipeline's own device:
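A minimal sketch of the intended selection, with stub run methods standing in for the real implementations (class and method names follow this review; the constructor signature is an assumption):

```python
class PipelineWithLayout:
    """Sketch: _internal_run should be keyed on the pipeline's own device,
    not merely on whether CUDA hardware is present."""

    def __init__(self, device, cuda_available):
        self.device = device
        # Wrong: self._internal_run = self._cuda_run if cuda_available else self._cpu_run
        # Right: a CPU pipeline must use _cpu_run even on a CUDA machine.
        if self.device == "gpu" and cuda_available:
            self._internal_run = self._cuda_run
        else:
            self._internal_run = self._cpu_run

    def _cuda_run(self, *args):
        return "cuda"  # stand-in for the stream-aware run path

    def _cpu_run(self, *args):
        return "cpu"   # stand-in for the plain run path
```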

Last reviewed commit: "Review fixes"

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

13 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +124 to +127
stream = torch.cuda.Stream(0)
with torch.cuda.stream(stream):
    output = self.pipe.run(stream, input_data=data_input)
Contributor

Using a cuda stream unconditionally might break CPU-only environments. Please wrap it with torch.cuda.is_available() or similar

Collaborator

Do we have CPU-only tests that will cover torchvision? If not, we should definitely add some to catch things like this

Collaborator Author

There are CPU-only test cases, but no test suite. I will add it.

Comment on lines +58 to +60
if len(tensor_or_tl) == 1:
    tensor_or_tl = tensor_or_tl[0]
return _to_torch_tensor(tensor_or_tl)
Contributor
@jantonguirao jantonguirao Mar 2, 2026

why do we want to unpack 1-element tuples? How about something like:

def to_torch_tensor(
    x: Union[tuple, 'TensorListGPU', 'TensorListCPU']
) -> Union[torch.Tensor, tuple]:
    if isinstance(x, (TensorListGPU, TensorListCPU)):
        return to_torch_tensor(x.as_tensor())
    elif isinstance(x, tuple):
        return tuple(to_torch_tensor(elem) for elem in x)
    else:
        return torch.from_dlpack(x)

Collaborator Author
@mdabek-nvidia mdabek-nvidia Mar 3, 2026

We need to unpack it since the Torchvision Compose operator returns a tensor or PIL.Image, not a tuple. DALI.Torchvision is supposed to be used as a drop-in replacement for Torchvision, and maximum compatibility should be maintained.
I will use the above code as a base for a better implementation of to_torch_tensor.
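The unpacking behaviour discussed here can be sketched in plain Python; `_convert` below is a placeholder for the real DLPack-based conversion, not the PR's actual function:

```python
def to_torch_tensor(output, _convert=lambda x: x):
    """Sketch: convert a pipeline output, unpacking single-element tuples
    so that a Compose with one transform returns a tensor rather than a
    (tensor,) tuple, matching torchvision's Compose semantics."""
    if isinstance(output, tuple):
        if len(output) == 1:
            return to_torch_tensor(output[0], _convert)
        return tuple(to_torch_tensor(e, _convert) for e in output)
    return _convert(output)
```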

Comment on lines +283 to +285
# This is WAR for DLPpack not supporting pinned memory
if output.device.device_type == "cpu":
    output = np.asarray(output)
Contributor

we should at least document this, or add a warning message

VerificationTensorOrImage.verify(data_input)

if self.active_pipeline is None:
    self._build_pipeline(data_input)
Contributor
@jantonguirao jantonguirao Mar 2, 2026

We should verify if we built a CHW or HWC pipeline, and rebuild in case the type changes. We can store the "pipeline kind" as a member, and initialize it during _build_pipeline.

Collaborator Author

Shouldn't changing the layout in the middle of the pipeline lifecycle be considered a bug?

[DEPRECATED but used]
"""

def __call__(self, data_input):
Contributor

this doesn't seem to fully cover the torchvision core semantics. Here's a suggestion (LLM generated):

class ToTensor:
    """
    Convert an image-like tensor to a float32 CHW tensor in [0, 1].

    Intended to be analogous to torchvision's ToTensor for the common image path:
    - input is uint8 in [0, 255] with layout "HWC" (e.g. PIL image converted to array)
    - output is float32 in [0, 1] with layout "CHW"

    Notes
    -----
    - This operates on DALI graph nodes (DataNode).
    - If the input is already floating point, it is only cast to float32 by default
      and not rescaled (to avoid double-scaling). Set `scale=True` if you want
      unconditional division by 255.
    """

    def __init__(
        self,
        *,
        output_layout: Literal["CHW", "HWC"] = "CHW",
        dtype=types.FLOAT,
        scale: bool = True,
        input_layout: Optional[Literal["HWC", "CHW"]] = None,
    ):
        """
        Parameters
        ----------
        output_layout:
            Layout of the output. torchvision uses CHW.
        dtype:
            Output dtype (default: FLOAT / float32).
        scale:
            If True, divide by 255 when input is integer type. If False, only cast/transpose.
        input_layout:
            If provided, forces how ToTensor interprets the input layout.
            If None, ToTensor assumes HWC (typical for PIL path in your Compose).
        """
        self.output_layout = output_layout
        self.dtype = dtype
        self.scale = scale
        self.input_layout = input_layout

    def __call__(self, data_input):
        # 1) Convert to float32 (or requested dtype)
        # If input is integer and scale=True, do scaling in float to get [0, 1].
        x = data_input

        # Cast first; DALI will handle conversion
        x = fn.cast(x, dtype=self.dtype)

        if self.scale:
            # This matches torchvision's ToTensor scaling for uint8-like images.
            # We do it unconditionally after cast; for non-uint8 integer inputs it still scales,
            # which is usually desired for "image" semantics. If you need stricter behavior,
            # you can make this conditional on original dtype (requires dtype introspection).
            x = x / 255.0

        # 2) Layout conversion
        in_layout = self.input_layout or "HWC"
        out_layout = self.output_layout

        if in_layout == out_layout:
            return x

        if in_layout == "HWC" and out_layout == "CHW":
            # permute HWC -> CHW
            x = fn.transpose(x, perm=[2, 0, 1])
            return x

        if in_layout == "CHW" and out_layout == "HWC":
            x = fn.transpose(x, perm=[1, 2, 0])
            return x

        raise ValueError(f"Unsupported layout conversion: {in_layout} -> {out_layout}")

Collaborator Author

ToTensor is a special case in this implementation, since DALI already returns tensors from a pipeline. I agree that scaling should be optional and that casting is missing.
What is more, the DALI implementation is limited and allows using ToTensor only as the final operator of the pipeline - this will be documented.

Collaborator Author

I removed ToTensor from this PR

Comment on lines +127 to +131
if no_size:
    if orig_h > orig_w:
        target_w = (max_size * orig_w) / orig_h
    else:
        target_h = (max_size * orig_h) / orig_w
Contributor
@jantonguirao jantonguirao Mar 2, 2026

this works in float, which might be incompatible with torchvision (it can lead to off-by-one errors, etc.).

Here's a suggestion (LLM generated, to be double checked) following the torch rounding rules:

from __future__ import annotations

from typing import Optional, Sequence, Tuple, Union
import math

SizeArg = Union[int, Sequence[int], None]

def _check_size_arg(size: SizeArg) -> SizeArg:
    # torchvision allows int or sequence len 1 or 2; len 1 treated like int
    if isinstance(size, (list, tuple)):
        if len(size) == 0 or len(size) > 2:
            raise ValueError(f"size must have len 1 or 2, got {len(size)}")
        if len(size) == 1:
            return int(size[0])
        return (int(size[0]), int(size[1]))
    if size is None:
        return None
    return int(size)

def _round_to_int(x: float) -> int:
    # torchvision uses rounding-to-nearest-int for computed sizes.
    # (In practice they do `int(x + 0.5)` for positive x or `round(x)` depending on version.)
    return int(math.floor(x + 0.5))

def torchvision_resize_output_size(
    orig_h: int,
    orig_w: int,
    size: SizeArg,
    max_size: Optional[int] = None,
) -> Tuple[int, int]:
    """
    Compute output size for torchvision.transforms.Resize(size, max_size=max_size)
    for 2D images.

    Returns (new_h, new_w) as ints.
    """
    size = _check_size_arg(size)

    if size is None:
        if max_size is None:
            raise ValueError("If size is None, max_size must be provided.")
        # Equivalent to "not_larger": longest edge becomes max_size
        # and aspect ratio is preserved.
        # This mirrors torchvision v2 functional behavior.
        if orig_w >= orig_h:
            new_w = int(max_size)
            new_h = _round_to_int(max_size * orig_h / orig_w)
        else:
            new_h = int(max_size)
            new_w = _round_to_int(max_size * orig_w / orig_h)
        return new_h, new_w

    # size is (h, w): direct
    if isinstance(size, tuple):
        if max_size is not None:
            raise ValueError("max_size must be None when size is a sequence of length 2.")
        return int(size[0]), int(size[1])

    # size is int: match shorter edge
    if max_size is not None and max_size <= 0:
        raise ValueError("max_size must be positive.")

    short, long_ = (orig_w, orig_h) if orig_w <= orig_h else (orig_h, orig_w)

    if short == size:
        # torchvision returns original size (no-op)
        new_h, new_w = orig_h, orig_w
    else:
        scale = float(size) / float(short)
        new_h = _round_to_int(orig_h * scale)
        new_w = _round_to_int(orig_w * scale)

    if max_size is not None:
        new_short, new_long = (new_w, new_h) if new_w <= new_h else (new_h, new_w)
        if new_long > max_size:
            scale = float(max_size) / float(new_long)
            new_h = _round_to_int(new_h * scale)
            new_w = _round_to_int(new_w * scale)

    # safety: clamp to at least 1
    new_h = max(1, int(new_h))
    new_w = max(1, int(new_w))
    return new_h, new_w

Collaborator Author

Added size calculation in Python

Comment on lines +102 to +103
tv_shape_lower = torch.Size([out_tv.shape[1] - 1, out_tv.shape[2] - 1])
tv_shape_upper = torch.Size([out_tv.shape[1] + 1, out_tv.shape[2] + 1])
Contributor

instead, you could follow torch's rounding rules and expect equality? At least to be checked if it is possible to achieve with DALI.

Collaborator

Strongly agree, full equality would be the best.

If the rounding errors are unavoidable, we should probably document where discrepancy comes from and make sure that we know the error bound (is it always +/-1? or can it be +/-2 in some scenarios?). The documentation should warn the users that they won't see bit-exact results.

If equality is not possible, I think using something like isclose would make this test code more readable, as you'll have something along the lines of assert torch.isclose(out_fn.shape[0], out_tv.shape[0], rtol=0, atol=1)

Collaborator Author

I would like it to be equal as well, but DALI calculates the output size differently from Torchvision.
The following change in resize_attr_base.cc::AdjustOutputSize:

-          out_size[d] = in_size[d] * scale[d];
+          out_size[d] = std::floor(in_size[d] * scale[d]);

would align both libraries, but it would change DALI's behavior.
There is an option to recalculate the output size in Python, but it would add additional overhead to resize, so I took the liberty of not doing that and instead documenting that the calculated size may be off by one.
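The off-by-one discussed here can be reproduced with plain arithmetic: truncation when a float extent is assigned to an integer (DALI's current behaviour as described above) versus torchvision's round-half-up. The helper names are illustrative, not library code:

```python
import math

def truncated_size(in_size, scale):
    # Truncation toward zero, as when a float is assigned to an int extent
    return int(in_size * scale)

def rounded_size(in_size, scale):
    # Round-half-up, matching torchvision's size computation
    return int(math.floor(in_size * scale + 0.5))

# Resizing a 480x640 image so the shorter edge becomes 224:
scale = 224 / 480
# truncated_size(640, scale) -> 298
# rounded_size(640, scale)   -> 299
```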

Collaborator Author

Added size calculation in Python

Collaborator
@szkarpinski szkarpinski left a comment

Leaving some comments after the first reading. My main concerns are:

  • There are a significant number of hard-coded dependencies on layouts (HWC, CHW) and color formats (RGB, RGBA). I'd like to make sure that this is intentional and that we don't anticipate any need to extend this soon.
  • The verification of arguments and argument handling seems to overlap with what DALI already does; duplicated input validation means added maintenance cost and a risk of discrepancies.

return torch.from_dlpack(dali_tensor)


def to_torch_tensor(tensor_or_tl: tuple | TensorListGPU | TensorListCPU) -> torch.Tensor:
Collaborator

This looks like a useful and generic thing, not strictly related to torchvision. Shouldn't this be a part of main DALI, or Torch plugin?


raise ValueError(f"Values {name} should be positive number, got {values}")


class VerifyIfOrderedPair(ArgumentVerificationRule):
Collaborator

Nitpick: ordered pair is something different

@params("gpu", "cpu")
def test_horizontal_random_flip_probability(device):
    img = make_test_tensor()
    transform = Compose([RandomHorizontalFlip(p=1.0, device=device)])  # always flip
Collaborator

Is Compose required to instantiate an operator? I believe that in torchvision you can use Compose to compose multiple operators, but Compose on a single-element list is a no-op and is not required nor encouraged

Collaborator Author

Is Compose required to instantiate an operator?

Yes, since Compose is a wrapper for a pipeline. It does not make much sense to instantiate standalone operators (and build a pipeline around each of them), since it would promote an antipattern and probably hinder performance. It is better to use the functional API, which uses dynamic mode.

Torchvision allows standalone operator instantiation, since its operators have low overhead when executed on the CPU.

Collaborator

And what happens when a user tries to use an operator without Compose (which is allowed in torchvision)? Will the error message be meaningful enough to guide the user towards the correct usage?

Collaborator Author

It will end up with an unrelated DALI exception. I don't think there is an obvious way to check whether an operator is inside a pipeline, other than passing a Compose to the pipeline's operators.
I plan to add a note in the documentation that DALI Torchvision operators need to be wrapped in Compose.

Collaborator

Maybe I underestimate the users, but I'm afraid very few will see this note and many will get discouraged by the unrelated DALI error - "if writing the simplest single-operator Torchvision transform results in a cryptic error, what's next?!".

I'd strongly suggest wrapping the Torchvision API DataNodes in some thin Python layer that would return a meaningful error message. Or, instead of wrapping, maybe we can override the __call__ of the operator.

If you believe it's not the right moment to do this, we can do it in a follow-up.

Member

I'd echo what Szymon says about the possible false impression that "this barely works" if one encounters such an error. The idea of overriding the __call__ operator in experimental.torchvision.functional feels promising - __call__ just throws, and we could have a private _call that calls super.
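A sketch of that idea (hypothetical base class and error message, not the PR's actual code): bare `__call__` fails loudly with guidance, while a private `_call` remains the entry point for Compose:

```python
class Operator:
    """Sketch: make standalone invocation fail with a meaningful message
    instead of an unrelated downstream DALI error."""

    def __call__(self, data_input):
        raise RuntimeError(
            f"{type(self).__name__} cannot be called standalone; wrap it in "
            "nvidia.dali.experimental.torchvision Compose, or use the "
            "functional API."
        )

    def _call(self, data_node):
        # Private entry point that Compose's pipeline builder would use.
        return self._kernel(data_node)

    def _kernel(self, data_node):
        raise NotImplementedError
```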


Signed-off-by: Marek Dabek <mdabek@nvidia.com>
@mdabek-nvidia
Collaborator Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [45326301]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [45326301]: BUILD FAILED

@mdabek-nvidia mdabek-nvidia force-pushed the torchvision_api_infra branch from 34bcd56 to 08bbd62 Compare March 5, 2026 14:50
@dali-automaton
Collaborator

CI MESSAGE: [45708918]: BUILD FAILED

@dali-automaton
Collaborator

CI MESSAGE: [45708918]: BUILD PASSED

output = super().run(_input)

if data_input.ndim == 3:
    # DALI requires batch size to be present
Contributor

Suggested change:

- # DALI requires batch size to be present
+ # Removing the batch dimension we added above

or similar would be more readable.

    device_id = data_input.device.index
else:
    device_id = torch.cuda.current_device()
stream = torch.cuda.Stream(device=device_id)
Contributor

we are creating a new stream every time, and we don't really synchronize with it or return it. Is this intentional?

Collaborator Author

Fixed

Comment on lines +34 to +40
def _to_torch_tensor(tensor_or_tl: TensorListGPU | TensorListCPU) -> torch.Tensor:
    if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)):
        dali_tensor = tensor_or_tl.as_tensor()
    else:
        dali_tensor = tensor_or_tl

    return torch.from_dlpack(dali_tensor)
Contributor

Are we sure the produced TensorList/Tensor stays alive long enough? Should we clone? Should we document “tensor is only valid until next pipeline run” (if that’s the case)? Just open questions

Member

I think dlpack should receive pycapsule that holds allocation and tie the reference lifetime to the returned object, the new executor is able to transfer ownership of its allocations, so it should be fine.

return output


def adjust_input(func):
Contributor

Layout handling is inconsistent across the stack (HW vs HWC for grayscale)
adjust_input converts PIL L to layout "HW" (no channel dim).
But other parts (e.g., functional.resize) treat "HW" as acceptable alongside "CHW"/"NCHW", and compute original_h/original_w accordingly.
Meanwhile PipelineHWC.run() converts PIL L into a tensor with a channel dim (unsqueeze(-1)), i.e. effectively HWC.
So grayscale is "HW" in dynamic functional path but HWC in pipeline path.

Just verifying that this is intentional. Should we document it? Should we stick to one representation for all cases?

Collaborator Author

Fixed - HWC will be used.
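The agreed normalization (grayscale always carried as HWC) can be sketched shape-wise in plain Python; the helper name is hypothetical:

```python
def normalize_to_hwc(shape, layout):
    """Sketch: give grayscale "HW" inputs an explicit channel dimension so
    both the functional and the pipeline paths see the same HWC layout,
    per the resolution above."""
    if layout == "HW":
        return shape + (1,), "HWC"
    if layout in ("HWC", "CHW"):
        return shape, layout
    raise ValueError(f"Unsupported layout: {layout!r}")
```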

Comment on lines +113 to +116
# TODO:
# assert torch.allclose(out_tv, out_dali_tv, rtol=1, atol=1)
# assert torch.allclose(out_fn, out_dali_fn, rtol=1, atol=1)

Contributor

do we plan to add that?

Collaborator Author

Yes, but currently changing interpolation or antialiasing returns different results. I turned it on for tensors, for now, since these tests are not running with different interpolation or antialiasing settings.

def test_large_sizes_images(resize, device):
    loop_images_test(resize=resize, device=device)


Contributor
@jantonguirao jantonguirao Mar 12, 2026

Missing tests for:

  • CUDA tensors input to Compose (and correctness / synchronization)
  • Compose called repeatedly with different sizes / dtypes / devices
  • grayscale layout consistency between functional and operator-based APIs
  • non-uint8 dtypes, float tensors, etc.

Collaborator Author

Added

Comment on lines +182 to +183
arg_rules: Sequence[ArgumentVerificationRule] = []
input_rules: Sequence[DataVerificationRule] = []
Contributor

How about making those tuples if they are not meant to be mutable?

Collaborator Author

Done

Signed-off-by: Marek Dabek <mdabek@nvidia.com>
@mdabek-nvidia mdabek-nvidia force-pushed the torchvision_api_infra branch from 38f8d3c to 54a3e57 Compare March 13, 2026 17:12
Signed-off-by: Marek Dabek <mdabek@nvidia.com>
Signed-off-by: Marek Dabek <mdabek@nvidia.com>
@mdabek-nvidia mdabek-nvidia force-pushed the torchvision_api_infra branch from a1be3e2 to cbf29d1 Compare March 14, 2026 18:04
@mdabek-nvidia
Copy link
Collaborator Author

!build

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [46236785]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [46236785]: BUILD PASSED

Comment on lines +24 to +26
total = 1
for s in shape:
total *= s
Member

nit:

Suggested change
total = 1
for s in shape:
total *= s
total = math.prod(shape) # or np.prod(shape)
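Both forms agree for every shape, including the empty (scalar) case, since `math.prod(())` returns 1, matching the loop's `total = 1` initialization:

```python
import math

def numel_loop(shape):
    # Original hand-rolled product from the diff above.
    total = 1
    for s in shape:
        total *= s
    return total

# The two implementations match, including empty and zero-sized shapes.
for shape in [(), (5,), (3, 4), (2, 0, 7)]:
    assert numel_loop(shape) == math.prod(shape)
```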

@params("gpu", "cpu")
def test_horizontal_random_flip_probability(device):
img = make_test_tensor()
transform = Compose([RandomHorizontalFlip(p=1.0, device=device)]) # always flip
Member

I'd echo what Szymon says about the possible false impression that "this barely works" if one encounters such an error. The idea of overriding the `__call__` operator in `experimental.torchvision.functional` feels promising: `__call__` just throws, and we could have a private `_call` that calls `super()`.
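A minimal sketch of that idea (class names are hypothetical, not the PR's actual types): the functional variant makes direct calls fail loudly, while a private `_call` still dispatches through the base class.

```python
class BaseTransform:
    def __call__(self, img):
        return self._process(img)

    def _process(self, img):
        return img  # placeholder for real processing

class FunctionalTransform(BaseTransform):
    def __call__(self, img):
        # Direct invocation fails with a clear message instead of a
        # confusing downstream error.
        raise TypeError(
            "transforms in experimental.torchvision.functional are not "
            "directly callable; use the functional entry points"
        )

    def _call(self, img):
        # Private entry point that still reaches the base behaviour.
        return super().__call__(img)
```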

loop_images_test(resize=resize, device=device)


@cartesian_params((512, 1125, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu"))
Member

nit: parentheses without comma don't do anything, do they?

Suggested change
@cartesian_params((512, 1125, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu"))
@cartesian_params((512, 1125, 2048, [512, 512], [2048, 2048]), ("cpu", "gpu"))
Suggested change
@cartesian_params((512, 1125, 2048, ([512, 512]), ([2048, 2048])), ("cpu", "gpu"))
@cartesian_params((512, 1125, 2048, (512, 512), (2048, 2048)), ("cpu", "gpu"))

Returns the size in a canonical form:

- ``int`` — resize the shorter edge to this value (aspect-ratio preserving)
- ``None`` — use ``max_size`` only (resize so longer edge equals ``max_size``)
Member

max_size is unused here?

Collaborator Author

Great catch :)

orig_h = orig_size[0]
orig_w = orig_size[1]

if isinstance(size, (tuple, list)):
Member

What if len(size) > 2?

Collaborator Author

Right, I added an exception
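A hedged sketch of the canonicalization with the new check (the helper name and messages are assumptions, not the PR's actual code): `int` keeps shorter-edge semantics, `None` defers to `max_size`, and sequences with more than two elements now raise.

```python
def canonical_size(size):
    """Validate a torchvision-style ``size`` argument (hypothetical helper)."""
    if size is None or isinstance(size, int):
        return size
    if isinstance(size, (tuple, list)):
        # Reject anything other than 1- or 2-element sequences.
        if len(size) not in (1, 2):
            raise ValueError(
                f"size must have 1 or 2 elements, got {len(size)}")
        return tuple(size)
    raise TypeError(f"unsupported size type: {type(size).__name__}")
```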

Comment on lines +34 to +40
def _to_torch_tensor(tensor_or_tl: TensorListGPU | TensorListCPU) -> torch.Tensor:
if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)):
dali_tensor = tensor_or_tl.as_tensor()
else:
dali_tensor = tensor_or_tl

return torch.from_dlpack(dali_tensor)
Member

I think DLPack should receive a PyCapsule that holds the allocation and ties the reference lifetime to the returned object. The new executor is able to transfer ownership of its allocations, so it should be fine.
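The lifetime point can be illustrated with NumPy's DLPack import, which behaves like `torch.from_dlpack`: the consumer holds a capsule that keeps the producer's allocation alive, and the import is zero-copy. This is a sketch only; the real code path above goes through DALI's `TensorList.as_tensor()`.

```python
import numpy as np

src = np.arange(6, dtype=np.int32)

# Zero-copy import through the DLPack protocol: the returned array
# shares src's allocation, and the capsule keeps that allocation
# alive even if the original reference were dropped.
view = np.from_dlpack(src)
same_memory = np.shares_memory(src, view)
```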

Signed-off-by: Marek Dabek <mdabek@nvidia.com>
@mdabek-nvidia mdabek-nvidia force-pushed the torchvision_api_infra branch from 423fc2a to 19fdda1 Compare March 18, 2026 12:00
@mdabek-nvidia
Collaborator Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [46425218]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [46425218]: BUILD PASSED

@mdabek-nvidia mdabek-nvidia merged commit 32a688c into NVIDIA:main Mar 18, 2026
7 checks passed