Skip to content

Commit 33a53cf

Browse files
authored
fix: prelu perf gap on Unet (#3717)
1 parent 17afde4 commit 33a53cf

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def check_weight_equal(
440440
except Exception:
441441
return torch.all(sd_weight == network_weight)
442442

443-
@needs_refit
443+
@needs_refit # type: ignore[misc]
444444
def _save_weight_mapping(self) -> None:
445445
"""
446446
Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -577,7 +577,7 @@ def _save_weight_mapping(self) -> None:
577577
gc.collect()
578578
torch.cuda.empty_cache()
579579

580-
@needs_refit
580+
@needs_refit # type: ignore[misc]
581581
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
582582
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
583583
# if not self.compilation_settings.strip_engine_weights:
@@ -605,7 +605,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
605605
),
606606
)
607607

608-
@needs_refit
608+
@needs_refit # type: ignore[misc]
609609
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
610610
# query the cached TRT engine
611611
cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr]
@@ -941,7 +941,14 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
941941
f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
942942
)
943943

944+
marked_outputs_ids = []
944945
for i, output in enumerate(outputs):
946+
# In some cases, the same output tensor may be marked multiple times, such as _to_copy,
947+
# so we skip marking if the output has already been marked
948+
if id(output) in marked_outputs_ids:
949+
continue
950+
marked_outputs_ids.append(id(output))
951+
945952
name = f"output{i}"
946953

947954
output_dtype = dtype.unknown

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,7 +1096,7 @@ def aten_ops_clone_copy_dtype(
10961096
name,
10971097
args[0],
10981098
kwargs.get("dtype", args[0].dtype),
1099-
force_layer=True,
1099+
force_layer=False, # force_layer=False results in better performance
11001100
)
11011101

11021102

@@ -1228,7 +1228,7 @@ def aten_ops_sum(
12281228
name,
12291229
sum_,
12301230
kwargs["output_dtype"],
1231-
force_layer=True,
1231+
force_layer=False, # force_layer=False results in better performance
12321232
)
12331233
else:
12341234
return sum_

0 commit comments

Comments (0)