Skip to content

Conversation

chohk88
Copy link
Collaborator

@chohk88 chohk88 commented Aug 6, 2025

Description

  • Removed forced float32 casting in the atan2 converter.
  • Fixed a bug in the dynamic shape logic for atan2: previously used π/2 * input instead of creating a constant tensor filled with π/2. This only passed tests because the input was zero.
  • Note: Replaced torch.rand() with torch.randint() for integer inputs in dynamic shape tests — this change may affect other converters using integer dynamic inputs.

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@chohk88 chohk88 self-assigned this Aug 6, 2025
@chohk88 chohk88 added the component: converters Issues re: Specific op converters label Aug 6, 2025
@meta-cla meta-cla bot added the cla signed label Aug 6, 2025
@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Aug 6, 2025
Copy link
Collaborator

@lanluo-nvidia lanluo-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
(I have verified this fix is working when use_explicit_typing is true)

@lanluo-nvidia lanluo-nvidia mentioned this pull request Aug 6, 2025
7 tasks
@chohk88
Copy link
Collaborator Author

chohk88 commented Aug 7, 2025

@lanluo-nvidia @narendasan

From the CI logs, it seems the issue might be caused by the change I made to generate integer random inputs for dynamic shapes. Previously, the inputs were all zeros, so the problem didn’t surface. I’ll look into fixing this part.

If strong type support for atan2 is urgent, I can remove the dynamic shape input change from this PR and submit it separately. Let me know what you think.

self = Input(shape=(2, 3, 4), dtype=dtype.b, format=memory_format.linear, domain=[0.0, 2.0))
optimization_profile_field = None

def example_tensor(
    self, optimization_profile_field: Optional[str] = None
) -> torch.Tensor:
    """
    Get an example tensor of the shape specified by the Input object

    Args:
        optimization_profile_field (Optional(str)): Name of the field to use for shape in the case the Input is dynamically shaped

    Returns:
        A PyTorch Tensor
    """
    if self.shape_mode == Input._ShapeMode.STATIC:
        if optimization_profile_field is not None:
            raise ValueError(
                "Specified a optimization profile field but the input is static"
            )
        else:
            if isinstance(self.shape, tuple):
                shape = self.shape
                dtype = self.dtype.to(torch.dtype, use_default=True)
                if dtype.is_floating_point:
                    return torch.rand(shape).to(dtype=dtype)
                else:
                    # For integer types, use randint to get a better range of values for testing
                  return torch.randint(-10, 10, shape, dtype=dtype)

E torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
E RuntimeError: from is out of bounds for bool
E
E Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
E
E
E To execute this test, run the following from the base repo dir:
E python test_decompositions.py TestLowering.test_masked_scatter_1_float16_3d
E
E This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

/opt/python/cp39-cp39/lib/python3.9/site-packages/torch_tensorrt/Input.py:389: BackendCompilerFailed
------------------------------ Captured log call -------------------------------
CRITICAL torch_tensorrt.dynamo.backend.backends:backends.py:180 Halting compilation on build failure since pass_through_build_failures was specified as True. To return the default Torch implementation and avoid halting compilation on engine build failures, specify pass_through_build_failures=False.
=============================== warnings summary ===============================
tests/py/dynamo/lowering/test_decompositions.py: 68 warnings
tests/py/dynamo/lowering/test_aten_lowering_passes.py: 3 warnings
/opt/python/cp39-cp39/lib/python3.9/site-packages/torch_tensorrt/dynamo/utils.py:275: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad
(True), rather than torch.tensor(sourceTensor).
torch.tensor(inputs),

tests/py/dynamo/lowering/test_decompositions.py: 68 warnings
tests/py/dynamo/lowering/test_aten_lowering_passes.py: 3 warnings
/opt/python/cp39-cp39/lib/python3.9/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:976: DeprecationWarning: Use Deprecated in TensorRT 10.12. Superseded by strong typing. instead.
output.dtype = output_dtype.to(trt.DataType, use_default=True)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html

  • generated xml file: /tmp/test_results/tests_py_dynamo_core_lowering_test_results.xml -
    =========================== short test summary info ============================
    FAILED lowering/test_decompositions.py::TestLowering::test_masked_scatter_0_float32_2d - torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
    RuntimeError: from is out of bounds for bool

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

To execute this test, run the following from the base repo dir:
python test_decompositions.py TestLowering.test_masked_scatter_0_float32_2d

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
FAILED lowering/test_decompositions.py::TestLowering::test_masked_scatter_1_float16_3d - torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
RuntimeError: from is out of bounds for bool

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

To execute this test, run the following from the base repo dir:
python test_decompositions.py TestLowering.test_masked_scatter_1_float16_3d

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
================= 2 failed, 77 passed, 142 warnings in 59.93s ==================

@chohk88
Copy link
Collaborator Author

chohk88 commented Aug 11, 2025

I will create a separate issue for py/torch_tensorrt/_Input.py as revising the masked_scatter lowering and converter will take more time.

@chohk88 chohk88 merged commit 33b8bbc into main Aug 12, 2025
81 of 85 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants