110 changes: 57 additions & 53 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -1,6 +1,6 @@
import logging
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
@@ -15,6 +15,7 @@
cast_trt_tensor,
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
@@ -23,6 +24,8 @@
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.types import TRTTensor

_LOGGER = logging.getLogger(__name__)


def trunc_div(
ctx: ConversionContext,
@@ -250,12 +253,26 @@ def atan2(
A TensorRT tensor representing the result of the atan2 operation.
"""
pi_value = 3.141592653589793
pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi")

if isinstance(input, TRTTensor):
input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input")
if isinstance(other, TRTTensor):
other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(input.dtype).to(torch.dtype),
_enums.dtype._from(other.dtype).to(torch.dtype),
)
)
# atan2's output is always float, so we promote any integer types to float32
# This mirrors PyTorch's behavior where atan2(int, int) -> float.
if not promoted_type.to(torch.dtype).is_floating_point:
promoted_type = _enums.dtype.float32

trt_promoted_type = promoted_type.to(trt.DataType)

pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi", dtype=trt_promoted_type)

if input.dtype != trt_promoted_type:
input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
if other.dtype != trt_promoted_type:
other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")

input, other = broadcast(ctx, input, other, f"{name}_input", f"{name}_other")

@@ -333,56 +350,43 @@ def atan2(
y_positive,
)

# Create constant tensors for boundary conditions (x=0 or y=0)
# Use impl.full which handles both dynamic and static shapes efficiently.
if has_dynamic_shape(input.shape):
pi_over_2_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
f"{name}_pi_over_2_tensor",
trt.ElementWiseOperation.PROD,
(pi_value / 2),
input,
)

minus_pi_over_2_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
f"{name}_minus_pi_over_2_tensor",
trt.ElementWiseOperation.PROD,
(-pi_value / 2),
input,
)
zero_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
f"{name}_zero_tensor",
trt.ElementWiseOperation.PROD,
0,
input,
)
shape_layer = ctx.net.add_shape(input)
set_layer_name(shape_layer, target, f"{name}_shape", source_ir)
shape = shape_layer.get_output(0)
else:
# on x or y-axis
pi_over_2_tensor = get_trt_tensor(
ctx,
(pi_value / 2) * np.ones(input.shape, dtype=np.float32),
f"{name}_pi_over_2_tensor",
dtype=trt.float32,
)
shape = list(input.shape)

minus_pi_over_2_tensor = get_trt_tensor(
ctx,
(-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
f"{name}_minus_pi_over_2_tensor",
dtype=trt.float32,
)
zero_tensor = get_trt_tensor(
ctx,
np.zeros(input.shape, dtype=np.float32),
f"{name}_zero_tensor",
dtype=trt.float32,
)
pi_over_2_tensor = impl.full.full(
ctx,
target,
source_ir,
f"{name}_pi_over_2_tensor",
shape,
pi_value / 2,
dtype=trt_promoted_type,
)

minus_pi_over_2_tensor = impl.full.full(
ctx,
target,
source_ir,
f"{name}_minus_pi_over_2_tensor",
shape,
-pi_value / 2,
dtype=trt_promoted_type,
)
zero_tensor = impl.full.full(
ctx,
target,
source_ir,
f"{name}_zero_tensor",
shape,
0.0,
dtype=trt_promoted_type,
)

# π/2 if x = 0 and y > 0,
pi_over_2_output = impl.condition.select(
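For reference, a small standalone check (not part of the diff) of the atan2 behavior the converter reproduces; the tensor values are illustrative only:

import math

import torch

# atan2 on integer tensors yields a floating-point result in PyTorch, which is
# why the converter bumps any non-floating promoted type up to float32.
y_int = torch.tensor([1, -1], dtype=torch.int32)
x_int = torch.tensor([0, 2], dtype=torch.int32)
assert torch.atan2(y_int, x_int).is_floating_point()

# Boundary cases covered by the pi_over_2 / minus_pi_over_2 / zero constants:
#   atan2(y=0, x>0) ->  0
#   atan2(y>0, x=0) ->  pi/2
#   atan2(y<0, x=0) -> -pi/2
y = torch.tensor([0.0, 1.0, -1.0])
x = torch.tensor([2.0, 0.0, 0.0])
expected = torch.tensor([0.0, math.pi / 2, -math.pi / 2])
assert torch.allclose(torch.atan2(y, x), expected)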
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/full.py
@@ -34,7 +34,8 @@ def full(
# in static shape scenario, shape is a list of int
if all(isinstance(dim, int) for dim in shape):
output_np_dtype = output_dtype.try_to(np.dtype, use_default=True)
return np.full(shape, fill_value, dtype=output_np_dtype)
np_array = np.full(shape, fill_value, dtype=output_np_dtype)
return get_trt_tensor(ctx, np_array, name, dtype=output_dtype)
else:
shape = impl.cat.cat(
ctx, target, source_ir, name + "_concat_shape", shape, 0
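For context, a rough sketch (not the library's actual code) of what wrapping the numpy constant via get_trt_tensor amounts to: the array is registered with the network as a constant layer, so full() now hands back an ITensor in the static-shape case too. The builder/network setup below is illustrative only.

import numpy as np
import tensorrt as trt

# Illustrative only: register a filled numpy array as a network constant and
# take the layer output, which is an ITensor usable as input to later layers.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()
np_array = np.full((2, 3), 0.5, dtype=np.float32)
const_layer = network.add_constant(np_array.shape, trt.Weights(np_array))
trt_tensor = const_layer.get_output(0)
print(isinstance(trt_tensor, trt.ITensor))  # True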