@@ -29,7 +29,6 @@
     DTensor = None

 import torch.nn.functional as F
-from packaging.version import Version
 from torch import nn
 from torch.onnx._globals import GLOBALS

@@ -1023,48 +1022,6 @@ def extra_repr(self):
         s += " calib" if (self._if_calib) else ""
         return s

-    @property
-    def mopt_ckpt_versn(self):
-        """Version of the checkpoint if it is restored from a checkpoint."""
-        return getattr(self, "_mopt_ckpt_versn", None)
-
-    @mopt_ckpt_versn.setter
-    def mopt_ckpt_versn(self, version: str):
-        self._mopt_ckpt_versn = str(version)
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        """Special handling for loading older checkpoints.
-
-        This implementation is for backward compatibility and can be deprecated in future versions.
-
-        Args:
-            state_dict: A dict containing the state of the top-level module.
-            prefix: A string that prefixes all of this module's state in state_dict, e.g. 'model.conv1.'.
-        """
-        if self.mopt_ckpt_versn is None or Version(self.mopt_ckpt_versn) >= Version("0.29"):
-            # Warnings below are raised if users intentionally load a partial state dict (e.g. HF ckpts).
-            # For ModelOpt >= 0.29 the buffers will be created correctly, so let's skip the warnings.
-            return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-        _attrs = ["_amax", "_pre_quant_scale", "_svdquant_lora_a", "_svdquant_lora_b"]
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        for attr in _attrs:
-            has_dst = attr in self._buffers
-            has_src = prefix + attr in state_dict
-
-            if not has_src and has_dst:
-                warnings.warn(f"{prefix[:-1]}: No {attr} in state_dict.")
-            elif has_src and not has_dst:
-                warnings.warn(
-                    f"{prefix[:-1]}: No '{attr}' buffer to load {attr} into."
-                    f" '{attr}' is created as a buffer for now. Please move the model to the correct device and "
-                    "dtype after this by calling `model.to(device, dtype)`."
-                )
-                self.register_buffer(attr, state_dict[prefix + attr].clone().detach().to(device))
-
-        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
     def _get_properties_for_modelopt_state(self):
         return (
             self.__dict__.keys()
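For reference, the removed `_load_from_state_dict` override only took its fallback path for checkpoints restored from ModelOpt versions older than 0.29 (hence the `packaging.version.Version` comparison); on that path it materialized quantizer tensors such as `_amax` as buffers on the fly when the checkpoint carried them but the module did not. Below is a minimal, self-contained sketch of that pattern; `LegacyQuantStub` and the checkpoint contents are hypothetical stand-ins, not ModelOpt API.

```python
import warnings

import torch
from torch import nn


class LegacyQuantStub(nn.Module):
    """Toy module standing in for a quantizer (hypothetical name)."""

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # If the incoming checkpoint carries a tensor that this module has not
        # registered as a buffer, register it on the fly so the load succeeds.
        for attr in ("_amax", "_pre_quant_scale"):
            key = prefix + attr
            if key in state_dict and attr not in self._buffers:
                warnings.warn(f"{prefix[:-1]}: creating missing buffer '{attr}'.")
                self.register_buffer(attr, state_dict[key].clone().detach())
        # Let the default machinery copy values into the (now existing) buffers.
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


# Usage: load a checkpoint that carries an `_amax` the module lacks.
module = LegacyQuantStub()
module.load_state_dict({"_amax": torch.tensor(4.0)}, strict=False)
print(module._amax)  # tensor(4.)
```

With this shim removed, the buffers are expected to exist before loading, as the deleted comment notes for ModelOpt >= 0.29.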