
Commit 2a0d40f

fix L0 RTX test issues (#3894)
Parent: 7b9eda4 · Commit: 2a0d40f

3 files changed: +17 / -16 lines

3 files changed

+17
-16
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 6 additions & 2 deletions
@@ -8,6 +8,7 @@
 import torch
 from tensorrt import ITensor as TRTTensor
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt import ENABLED_FEATURES
 from torch_tensorrt._features import needs_not_tensorrt_rtx
 from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -427,8 +428,8 @@ def index_dtype_validator(
 def index_nonbool_validator(
     node: Node, settings: Optional[CompilationSettings] = None
 ) -> bool:
-    # for thor, we don't support boolean indices
-    if is_thor():
+    # for thor and tensorrt_rtx, we don't support boolean indices, due to nonzero op not supported
+    if is_thor() or ENABLED_FEATURES.tensorrt_rtx:
         index = node.args[1]
         for ind in index:
             if ind is not None:
@@ -903,6 +904,8 @@ def aten_ops_select(
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.index_put.default,
+    capability_validator=lambda node, settings: index_dtype_validator(node, settings)
+    and index_nonbool_validator(node, settings),
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
@@ -2786,6 +2789,7 @@ def aten_ops_max_pool(
 @dynamo_tensorrt_converter(
     torch.ops.aten._reshape_copy.default, supports_dynamic_shapes=True
 )
+@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
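
For context, the capability_validator shown above decides per node whether the registered TensorRT converter may handle it; when the check fails, the node is left to run outside TensorRT. Below is a minimal, standalone sketch of the same composition pattern used in the index_put registration, where the lambda requires both validators to pass. The names all_of, dtype_ok, nonbool_ok, and FakeNode are illustrative stand-ins, not torch_tensorrt APIs:

from typing import Any, Callable, Optional

Validator = Callable[[Any, Optional[Any]], bool]

def all_of(*validators: Validator) -> Validator:
    # Capable only when every individual check passes, mirroring the
    # "index_dtype_validator(...) and index_nonbool_validator(...)" lambda above.
    return lambda node, settings=None: all(v(node, settings) for v in validators)

def dtype_ok(node: Any, settings: Optional[Any] = None) -> bool:
    return getattr(node, "index_dtype", "int64") in ("int32", "int64")

def nonbool_ok(node: Any, settings: Optional[Any] = None) -> bool:
    # Boolean masks would require a nonzero op, which Thor / tensorrt_rtx lack.
    return getattr(node, "index_dtype", "int64") != "bool"

class FakeNode:
    def __init__(self, index_dtype: str) -> None:
        self.index_dtype = index_dtype

combined = all_of(dtype_ok, nonbool_ok)
print(combined(FakeNode("int64")))  # True  -> converter would be used
print(combined(FakeNode("bool")))   # False -> node stays in PyTorch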

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 9 additions & 5 deletions
@@ -5,7 +5,7 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt import Input
+from torch_tensorrt import ENABLED_FEATURES, Input
 from torch_tensorrt._utils import is_tegra_platform, is_thor
 
 from .harness import DispatchTestCase
@@ -114,8 +114,8 @@ def forward(self, input):
         ]
     )
     @unittest.skipIf(
-        is_thor(),
-        "Skipped on Thor due to nonzero not supported",
+        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
     )
     def test_index_constant_bool_mask(self, _, index, input):
         class TestModule(torch.nn.Module):
@@ -148,6 +148,10 @@ def forward(self, x, index0):
             [input, index0],
         )
 
+    @unittest.skipIf(
+        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+    )
     def test_index_zero_two_dim_ITensor_mask(self):
         class TestModule(nn.Module):
             def forward(self, x, index0):
@@ -176,8 +180,8 @@ def forward(self, x, index0):
         self.run_test(TestModule(), [input, index0])
 
     @unittest.skipIf(
-        is_thor(),
-        "Skipped on Thor due to nonzero not supported",
+        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_index_three_dim_mask_ITensor(self):
         class TestModule(nn.Module):
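
The same is_thor() or ENABLED_FEATURES.tensorrt_rtx guard now appears on three tests. A small self-contained sketch, with stand-in definitions rather than the real imports, of how the repeated condition and reason string could be hoisted into module-level constants; the constant names are hypothetical, not part of the patch:

import unittest

# Stand-ins for is_thor() and ENABLED_FEATURES.tensorrt_rtx from the real modules.
def is_thor() -> bool:
    return False

class _Features:
    tensorrt_rtx = False

ENABLED_FEATURES = _Features()

# skipIf conditions are evaluated once at decoration time, so a module-level
# constant keeps all skip-guarded tests in sync.
NO_NONZERO_SUPPORT = is_thor() or ENABLED_FEATURES.tensorrt_rtx
NO_NONZERO_REASON = "Skipped on Thor or tensorrt_rtx due to nonzero not supported"

class TestBoolMaskIndexing(unittest.TestCase):
    @unittest.skipIf(NO_NONZERO_SUPPORT, NO_NONZERO_REASON)
    def test_constant_bool_mask(self) -> None:
        # Boolean-mask indexing needs a nonzero op, hence the platform skip.
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()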

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 2 additions & 9 deletions
@@ -1679,15 +1679,8 @@ def forward(self, x):
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
         torch_model_results = fx_graph(*inputs).detach().cpu()
-
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
-        )
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            DECIMALS_OF_AGREEMENT,
-            f"Log_softmax TRT outputs don't match with the original model.",
+        assert torch.allclose(
+            optimized_model_results, torch_model_results, atol=1e-3, rtol=1e-3
         )
 
     @parameterized.expand(
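
For reference, torch.allclose(actual, expected, atol, rtol) passes when |actual - expected| <= atol + rtol * |expected| holds element-wise, whereas the old check rounded a single max-difference scalar to a fixed number of decimals. A standalone sketch of both styles with illustrative tensors and tolerances:

import torch

expected = torch.randn(4, 8)
actual = expected + 2e-4  # small, uniform numerical drift for illustration

# Old style: reduce to one scalar and assert it is ~0 to a fixed number of decimals,
# the shape of the assertAlmostEqual(max_diff, 0, places) check in the deleted code
# (places=3 here is illustrative, not the DECIMALS_OF_AGREEMENT value).
max_diff = float(torch.max(torch.abs(actual - expected)))
assert round(max_diff, 3) == 0

# New style: element-wise tolerance, |actual - expected| <= atol + rtol * |expected|.
assert torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)
print("both checks pass; max abs diff =", max_diff)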
