
Commit 0fc2a62

WAR the fact that it is not possible to set __setattr__ dynamically
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
1 parent 778019d commit 0fc2a62
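
Background for this change: Python resolves special methods such as __setattr__ on the type, not on the instance, so the previous per-instance assignment `self.__setattr__ = self.default_setattr` never actually changed how attribute assignment was dispatched. A minimal, self-contained sketch of that limitation (illustrative names, not code from this repo):

# A minimal sketch (illustrative names, not code from this commit): Python
# looks up special methods such as __setattr__ on the type, so assigning one
# on the instance has no effect on later attribute assignments.

class Widget:
    def fast_setattr(self, name, value):
        object.__setattr__(self, name, value)

    def __init__(self):
        # Stores a bound method under the instance attribute "__setattr__",
        # but does NOT change how `w.attr = ...` is dispatched.
        self.__setattr__ = self.fast_setattr


w = Widget()
w.x = 1  # still handled by type(w).__setattr__, i.e. object.__setattr__
print(type(w).__setattr__ is object.__setattr__)  # True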


6 files changed (+12, −6 lines)


transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 1 addition & 1 deletion
@@ -482,7 +482,7 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse

         self.register_load_state_dict_post_hook(remove_extra_states_check)

-        self.__setattr__ = self.default_setattr
+        self._default_setattr = self._warning_setattr

     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs

transformer_engine/pytorch/module/base.py

Lines changed: 7 additions & 1 deletion
@@ -639,7 +639,7 @@ def fast_setattr(self, name: str, value: Any) -> None:

     def module_setattr(self, name: str, value: Any) -> None:
         super().__setattr__(name, value)

-    def default_setattr(self, name: str, value: Any) -> None:
+    def _warning_setattr(self, name: str, value: Any) -> None:
         warnings.warn(
             """The default implementation of torch.nn.Module introduces significant CPU overhead
             when setting attributes and is therefore not recommended. Please use the explicit calls

@@ -649,6 +649,12 @@ def default_setattr(self, name: str, value: Any) -> None:

         )
         self.module_setattr(name, value)

+    def _default_setattr(self, name: str, value: Any) -> None:
+        return self.module_setattr(name, value)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        return self._default_setattr(name, value)
+
     def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """
         Delayed scaling only.
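
The hunk above establishes the workaround: __setattr__ is defined once on the class and merely forwards to an ordinary attribute, _default_setattr, which each instance may rebind. A simplified sketch of this dispatch pattern, assuming plain Python in place of torch.nn.Module (ModuleBase and NoisyModule are illustrative stand-ins, not the actual TransformerEngine classes):

# Sketch of the workaround pattern (simplified, not the real classes): the
# special method lives on the class and forwards to an ordinary attribute.

import warnings


class ModuleBase:
    def module_setattr(self, name, value):
        # Plain attribute storage; the real code defers to
        # torch.nn.Module.__setattr__ via super().
        object.__setattr__(self, name, value)

    def _warning_setattr(self, name, value):
        warnings.warn("Setting attributes this way is slow; prefer explicit setters.")
        self.module_setattr(name, value)

    def _default_setattr(self, name, value):
        return self.module_setattr(name, value)

    def __setattr__(self, name, value):
        # Resolved on the type, as Python requires, but the actual behavior
        # is selected per instance through _default_setattr.
        return self._default_setattr(name, value)


class NoisyModule(ModuleBase):
    def __init__(self):
        # Rebinding an ordinary attribute per instance works, unlike
        # rebinding __setattr__ itself.
        self._default_setattr = self._warning_setattr


m = NoisyModule()
m.weight = 3  # goes through _warning_setattr: warns, then stores the value

The call sites below then opt into the warning behavior by rebinding the plain attribute (self._default_setattr = self._warning_setattr) at the end of __init__, instead of trying to replace __setattr__ on the instance.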

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 1 addition & 1 deletion
@@ -716,7 +716,7 @@ def __init__(

             if name in (f"weight{i}", f"bias{i}"):
                 param.skip_backward_post_hook = True

-        self.__setattr__ = self.default_setattr
+        self._default_setattr = self._warning_setattr

     def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         """Init scales and amaxes for fwd | bwd."""

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 1 addition & 1 deletion
@@ -1405,7 +1405,7 @@ def __init__(

             if name in self.weight_names or name in self.bias_names:
                 param.skip_backward_post_hook = True

-        self.__setattr__ = self.default_setattr
+        self._default_setattr = self._warning_setattr

     def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         """Init scales and amaxes for fwd | bwd."""

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 1 addition & 1 deletion
@@ -1960,7 +1960,7 @@ def __init__(

         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

-        self.__setattr__ = self.default_setattr
+        self._default_setattr = self._warning_setattr

     def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         """Init scales and amaxes for fwd | bwd."""

transformer_engine/pytorch/module/linear.py

Lines changed: 1 addition & 1 deletion
@@ -1309,7 +1309,7 @@ def __init__(

             if name in self.weight_names or name in self.bias_names:
                 param.skip_backward_post_hook = True

-        self.__setattr__ = self.default_setattr
+        self._default_setattr = self._warning_setattr

     def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         """Init scales and amaxes for fwd | bwd."""
