@@ -62,7 +62,7 @@
     load_or_create_model_card,
     populate_model_card,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .model_loading_utils import (
     _caching_allocator_warmup,
     _determine_device_map,
@@ -1540,10 +1540,7 @@ def _load_pretrained_model(
         assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
         error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)

-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()

         if offload_index is not None and len(offload_index) > 0:
             save_offload_index(offload_index, offload_folder)
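The comment deleted in the hunk above describes the usual pairing of `non_blocking=True` copies with an explicit device synchronization before tensors are handed back to the caller. A minimal, self-contained sketch of that general pattern, with purely illustrative tensor names and sizes (not part of the diffusers loading path):

```python
import torch

# A host-to-device copy issued with non_blocking=True is queued on the current CUDA
# stream and the Python call returns immediately, so the transfer may still be in
# flight afterwards. An explicit synchronize guarantees the data has landed on the
# device before the tensor is handed to code that may not share the same stream.
if torch.cuda.is_available():
    cpu_tensor = torch.randn(1024, pin_memory=True)       # pinned memory enables a truly asynchronous copy
    gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)  # enqueue the copy; do not wait for it
    torch.cuda.synchronize()                               # block until all queued work, including the copy, completes
```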
@@ -1880,4 +1877,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)

-        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
+        if remapped_class is cls:
+            return super(LegacyModelMixin, remapped_class).from_pretrained(
+                pretrained_model_name_or_path, **kwargs_copy
+            )
+        else:
+            return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
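The added branch dispatches through `super()` when the remapping resolves back to the same class, so the legacy override never re-enters itself; only a genuinely different target class is delegated to directly. A minimal sketch of that dispatch pattern, using hypothetical names (`BaseLoader`, `LegacyLoader`, and `resolve_remapping` are stand-ins for illustration, not diffusers APIs):

```python
class BaseLoader:
    @classmethod
    def from_pretrained(cls, path, **kwargs):
        # Stand-in for the actual loading logic provided by the parent class.
        return f"{cls.__name__} loaded from {path}"


def resolve_remapping(cls):
    # Hypothetical stand-in for a config-based class remapping lookup; returning the
    # class unchanged is exactly the case the new `if` branch has to handle.
    return cls


class LegacyLoader(BaseLoader):
    @classmethod
    def from_pretrained(cls, path, **kwargs):
        remapped_class = resolve_remapping(cls)
        if remapped_class is cls:
            # Calling remapped_class.from_pretrained here would re-enter this very
            # method; skipping one level up the MRO avoids the infinite recursion.
            return super(LegacyLoader, remapped_class).from_pretrained(path, **kwargs)
        # A genuinely different class gets full control over its own loading path.
        return remapped_class.from_pretrained(path, **kwargs)


print(LegacyLoader.from_pretrained("some/checkpoint"))  # -> "LegacyLoader loaded from some/checkpoint"
```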