fix: prelu perf gap on Unet #3717
base: main
Changes from 6 commits
2f19666
79b3153
cdc31d7
d1d18b9
cf9e7bd
11550df
dbb5c78
9745a53
@@ -440,7 +440,7 @@ def check_weight_equal(
         except Exception:
             return torch.all(sd_weight == network_weight)
 
-    @needs_refit
+    @needs_refit  # type: ignore[misc]
     def _save_weight_mapping(self) -> None:
         """
         Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -577,7 +577,7 @@ def _save_weight_mapping(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()
 
-    @needs_refit
+    @needs_refit  # type: ignore[misc]
     def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
@@ -605,7 +605,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
             ),
         )
 
-    @needs_refit
+    @needs_refit  # type: ignore[misc]
     def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
         # query the cached TRT engine
         cached_data = self.engine_cache.check(hash_val)  # type: ignore[union-attr]
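For context on the recurring change above (my inference, not stated in the diff): mypy's [misc] error code covers "Untyped decorator makes function ... untyped", which fires whenever an unannotated decorator wraps an annotated method, so each @needs_refit use site gets a targeted ignore. A minimal reproduction with a simplified, hypothetical stand-in for needs_refit:

from typing import Any

# Hypothetical, simplified stand-in for needs_refit (not the PR's code).
# The decorator itself carries no type annotations, so under mypy --strict
# every decorated method triggers:
#   error: Untyped decorator makes function "save" untyped  [misc]
# which is what the added "# type: ignore[misc]" comments silence.
def needs_refit(fn):
    def wrapper(*args: Any, **kwargs: Any):
        return fn(*args, **kwargs)

    return wrapper


class Interpreter:
    @needs_refit  # type: ignore[misc]
    def save(self) -> None:
        print("saving weight mapping")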
@@ -941,7 +941,14 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
                 f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
             )
 
+        marked_outputs_ids = []
         for i, output in enumerate(outputs):
+            # In some cases, the same output tensor may be marked multiple times, such as _to_copy,
+            # so we skip marking if the output is already marked
+            if id(output) in marked_outputs_ids:
+                continue
+            marked_outputs_ids.append(id(output))
+
             name = f"output{i}"
 
             output_dtype = dtype.unknown

Review comment: Where does this id function come from?
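On the reviewer's question: id() is the Python builtin that returns an object's identity (an integer unique for the object's lifetime), so the loop skips outputs that alias the same tensor object. A standalone sketch of the pattern, with hypothetical data rather than the PR's tensors:

# id() returns a unique integer per live object, so two names that alias
# the same object compare equal under id() even if their values match
# another, distinct object.
a = [1.0, 2.0]
outputs = [a, a, [1.0, 2.0]]  # first two entries alias one object

marked_ids: list[int] = []
for i, out in enumerate(outputs):
    if id(out) in marked_ids:  # this exact object was already marked
        continue
    marked_ids.append(id(out))
    print(f"marking output{i}")  # prints output0 and output2 only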
@@ -1094,7 +1094,7 @@ def aten_ops_clone_copy_dtype(
         name,
         args[0],
         kwargs.get("dtype", args[0].dtype),
-        force_layer=True,
+        force_layer=False,
     )

Review comment: I think there might be cases where we need to actually force_layer=True. Do you know when that would be useful? Also consider adding a comment here conveying that force_layer=False results in better performance.

Reply: Do you know when we need force_layer=True? My understanding is that 1) Since …
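A plausible reading of the flag (an assumption on my part; the diff itself does not define it): with force_layer=False the cast helper may return the input tensor unchanged when it already has the requested dtype, while force_layer=True always materializes an explicit TensorRT layer. A sketch of that dispatch with hypothetical names, not Torch-TensorRT's actual helper:

# Hypothetical sketch of the force_layer semantics discussed above.
# `network` is assumed to be a tensorrt.INetworkDefinition and `tensor`
# a tensorrt.ITensor; the real converter utility differs in detail.
def cast_tensor(network, tensor, target_dtype, name, force_layer=False):
    if not force_layer and tensor.dtype == target_dtype:
        # No-op: reuse the existing tensor, so TensorRT has one fewer
        # layer to build and schedule (the perf win motivating this PR).
        return tensor
    layer = network.add_identity(tensor)  # explicit cast via identity layer
    layer.set_output_type(0, target_dtype)
    layer.name = name
    return layer.get_output(0)

One case where forcing a layer could still matter is when a distinct ITensor is required, e.g. a network output that must be produced by some layer rather than alias an input; that may be the situation the reviewer is asking about.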
@@ -1226,7 +1226,7 @@ def aten_ops_sum(
             name,
             sum_,
             kwargs["output_dtype"],
-            force_layer=True,
+            force_layer=False,
         )
     else:
         return sum_