From 427fcb3709b2f0b476c48f383c42c4d41d05099b Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Thu, 24 Jul 2025 13:54:25 +0000 Subject: [PATCH 1/2] fix: Remove type casting in matmul and add scalar tensor conversion --- .../dynamo/conversion/impl/elementwise/base.py | 9 ++++++++- py/torch_tensorrt/dynamo/conversion/impl/matmul.py | 11 ----------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index ab9629b0db..030d3638bd 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -2,7 +2,6 @@ import warnings from typing import Any, Callable, Optional, Union -import numpy as np import tensorrt as trt import torch from torch.fx.node import Target @@ -103,6 +102,14 @@ def convert_binary_elementwise( rhs_dtype = rhs_val.dtype is_rhs_trt_tensor = True + # Convert 0-dimensional tensors (scalars) to Python scalars + # This ensures proper dtype promotion: scalar operands adopt the tensor's dtype + # following PyTorch's type promotion rules for arithmetic operations + if isinstance(lhs_val, torch.Tensor) and len(lhs_val.shape) == 0: + lhs_val = lhs_val.item() + if isinstance(rhs_val, torch.Tensor) and len(rhs_val.shape) == 0: + rhs_val = rhs_val.item() + if not is_lhs_trt_tensor and not is_rhs_trt_tensor: warnings.warn( f"Both operands of the binary elementwise op {name} " diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index 83ea3dd99b..dfc917bc00 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -8,7 +8,6 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( broadcast, - cast_trt_tensor, get_trt_tensor, set_layer_name, ) @@ -48,16 +47,6 @@ def matrix_multiply( input, other = broadcast( ctx, input, other, f"{name}_input", f"{name}_other", preset_diff ) - if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED): - promoted_type = _enums.dtype._from( - torch.promote_types( - _enums.dtype._from(input.dtype).to(torch.dtype), - _enums.dtype._from(other.dtype).to(torch.dtype), - ) - ) - trt_promoted_type = promoted_type.to(trt.DataType) - input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted") - other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted") layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op) set_layer_name(layer, target, name, source_ir) From 1b8dbd8f251198e0d68bf9dbee64a00cdc258f21 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Tue, 29 Jul 2025 15:47:23 +0900 Subject: [PATCH 2/2] chore: Handle scalar tensor type promotion --- .../dynamo/conversion/impl/elementwise/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 030d3638bd..42880646c3 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -102,13 +102,13 @@ def convert_binary_elementwise( rhs_dtype = rhs_val.dtype is_rhs_trt_tensor = True - # Convert 0-dimensional tensors (scalars) to Python scalars - # This ensures proper dtype promotion: scalar operands adopt the tensor's dtype - # following PyTorch's type promotion rules for arithmetic operations + # Handle scalar tensor type promotion for elementwise operations + # When one operand is a scalar tensor (0-dimensional), promote its dtype to match the other operand + # This ensures consistent type handling in Torch elementwise operations if isinstance(lhs_val, torch.Tensor) and len(lhs_val.shape) == 0: - lhs_val = lhs_val.item() + lhs_dtype = rhs_dtype if isinstance(rhs_val, torch.Tensor) and len(rhs_val.shape) == 0: - rhs_val = rhs_val.item() + rhs_dtype = lhs_dtype if not is_lhs_trt_tensor and not is_rhs_trt_tensor: warnings.warn(