29 | 29 | DTensor = None |
30 | 30 | |
31 | 31 | import torch.nn.functional as F |
32 | | -from packaging.version import Version |
33 | 32 | from torch import nn |
34 | 33 | from torch.onnx._globals import GLOBALS |
35 | 34 | |
@@ -1023,48 +1022,6 @@ def extra_repr(self): |
1023 | 1022 | s += " calib" if (self._if_calib) else "" |
1024 | 1023 | return s |
1025 | 1024 | |
1026 | | - @property |
1027 | | - def mopt_ckpt_versn(self): |
1028 | | - """Version of the checkpoint if it is restored from a checkpoint.""" |
1029 | | - return getattr(self, "_mopt_ckpt_versn", None) |
1030 | | - |
1031 | | - @mopt_ckpt_versn.setter |
1032 | | - def mopt_ckpt_versn(self, version: str): |
1033 | | - self._mopt_ckpt_versn = str(version) |
1034 | | - |
1035 | | - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): |
1036 | | - """Special handling for loading older checkpoints. |
1037 | | - |
1038 | | - This implementation exists for backward compatibility and may be removed in a future version. |
1039 | | - |
1040 | | - Args: |
1041 | | - state_dict: A dict containing the state of the top level module |
1042 | | - prefix: A string that prefixes all of this module's state in state_dict, e.g. 'model.conv1.' |
1043 | | - """ |
1044 | | - if self.mopt_ckpt_versn is None or Version(self.mopt_ckpt_versn) >= Version("0.29"): |
1045 | | - # Warnings below are raised if users intentionally use a partial state dictionary (e.g., HF ckpts) |
1046 | | - # For ModelOpt >= 0.29, the buffers will be correctly created, so let's skip the warnings |
1047 | | - return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
1048 | | - |
1049 | | - _attrs = ["_amax", "_pre_quant_scale", "_svdquant_lora_a", "_svdquant_lora_b"] |
1050 | | - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
1051 | | - |
1052 | | - for attr in _attrs: |
1053 | | - has_dst = attr in self._buffers |
1054 | | - has_src = prefix + attr in state_dict |
1055 | | - |
1056 | | - if not has_src and has_dst: |
1057 | | - warnings.warn(f"{prefix[:-1]}: No {attr} in state_dict.") |
1058 | | - elif has_src and not has_dst: |
1059 | | - warnings.warn( |
1060 | | - f"{prefix[:-1]}: No '{attr}' buffer to load {attr} into." |
1061 | | - f" '{attr}' is created as a buffer for now. Please move the model to the correct device and " |
1062 | | - "dtype after this by calling `model.to(device, dtype)`." |
1063 | | - ) |
1064 | | - self.register_buffer(attr, state_dict[prefix + attr].clone().detach().to(device)) |
1065 | | - |
1066 | | - super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
1067 | | - |
1068 | 1025 | def _get_properties_for_modelopt_state(self): |
1069 | 1026 | return ( |
1070 | 1027 | self.__dict__.keys() |
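For context on what this hunk removes: the deleted `_load_from_state_dict` override gated legacy handling on the stored checkpoint version, and for pre-0.29 checkpoints it created any of the quantizer buffers (`_amax`, `_pre_quant_scale`, `_svdquant_lora_a`, `_svdquant_lora_b`) present in the incoming state_dict but missing from the module, warning either way, before deferring to PyTorch's default loader. Below is a minimal sketch of that pattern, assuming a toy `LegacyAwareModule` with an illustrative `ckpt_version` attribute and a reduced buffer list; none of these names are the actual ModelOpt API.

```python
import warnings

import torch
from packaging.version import Version
from torch import nn


class LegacyAwareModule(nn.Module):
    """Toy module showing version-gated state_dict loading (illustrative only)."""

    # Stand-ins for the real quantizer buffer names.
    _compat_buffers = ("_amax", "_pre_quant_scale")

    def __init__(self):
        super().__init__()
        self.ckpt_version = None  # a restore path would set this before load_state_dict

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # New-enough checkpoints already match the module's buffers: use the stock loader.
        if self.ckpt_version is None or Version(self.ckpt_version) >= Version("0.29"):
            return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for attr in self._compat_buffers:
            has_dst = attr in self._buffers
            has_src = prefix + attr in state_dict
            if has_src and not has_dst:
                # Register a buffer so the stock loader has somewhere to copy into;
                # the caller should fix device/dtype afterwards via model.to(...).
                warnings.warn(f"{prefix[:-1]}: creating buffer '{attr}' to load into.")
                self.register_buffer(attr, state_dict[prefix + attr].detach().clone().to(device))
            elif has_dst and not has_src:
                warnings.warn(f"{prefix[:-1]}: no '{attr}' in state_dict.")
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
```

With a module whose `ckpt_version` is, say, `"0.27"`, `load_state_dict` takes the compat branch and materializes the missing buffers; anything `>= "0.29"` (or unset) falls straight through to the default loader, which is why the override, and the `packaging.version` import above, can be dropped once old checkpoints are out of scope.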