Skip to content

Commit f566fc1

Browse files
fix guard_fn issue (#3815)
1 parent 7711ffe commit f566fc1

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/py/dynamo/conversion/harness.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,17 @@
2626
post_lowering,
2727
pre_export_lowering,
2828
)
29+
from torch_tensorrt.dynamo.lowering.passes import remove_num_users_is_0_nodes
2930
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
3031
from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_model_device, get_torch_inputs
3132

3233
_LOGGER: logging.Logger = logging.getLogger(__name__)
3334

35+
# this is the post lowering pass list for the converter test
36+
post_lowering_pass_list_for_converter_test = [
37+
remove_num_users_is_0_nodes,
38+
]
39+
3440

3541
# this method is only used in our converter test to infer the module output dtypes via dummy inference
3642
# which is due to fx.symbolic_trace does not have the meta['val'] info in the node
@@ -435,6 +441,8 @@ def run_test(
435441
settings=compilation_settings,
436442
)
437443

444+
for pass_func in post_lowering_pass_list_for_converter_test:
445+
mod = pass_func(mod, compilation_settings)
438446
num_inputs = len(inputs)
439447
trt_inputs = inputs
440448
dtype_to_change = []

0 commit comments

Comments
 (0)