|
@@ -42,6 +42,7 @@
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
+    FLASHPACK_WEIGHTS_NAME,
     FLAX_WEIGHTS_NAME,
     HF_ENABLE_PARALLEL_LOADING,
     SAFE_WEIGHTS_INDEX_NAME,
|
@@ -55,6 +56,7 @@
     is_accelerate_available,
     is_bitsandbytes_available,
     is_bitsandbytes_version,
+    is_flashpack_available,
     is_peft_available,
     is_torch_version,
     logging,
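Note for reviewers: `FLASHPACK_WEIGHTS_NAME` is imported from `..utils` above, but its definition sits outside this diff. In diffusers, weight-file names live next to `WEIGHTS_NAME` and `SAFETENSORS_WEIGHTS_NAME` in `utils/constants.py`; a hypothetical sketch of the missing counterpart (the exact filename value is an assumption, not shown in the PR):

```python
# Hypothetical addition to src/diffusers/utils/constants.py.
# The constant name comes from this diff; the filename value is a guess.
FLASHPACK_WEIGHTS_NAME = "diffusion_pytorch_model.flashpack"
```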
@@ -913,6 +915,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
             disable_mmap (`bool`, *optional*, defaults to `False`):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is loaded from `flashpack` weights.

         > [!TIP]
         > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
         > with `hf auth login`. You can also activate the special
         > ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
         > firewalled environment.
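A minimal usage sketch of the new flag, assuming a repo that ships a FlashPack weights file (the model class and checkpoint id below are placeholders, not part of this PR):

```python
from diffusers import AutoencoderKL

# Placeholder checkpoint id: the repo must contain the file named by
# FLASHPACK_WEIGHTS_NAME for the new resolution branch to find it.
vae = AutoencoderKL.from_pretrained(
    "user/my-vae-with-flashpack-weights",  # hypothetical repo id
    use_flashpack=True,                    # new kwarg added in this PR
)
```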
@@ -957,6 +961,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
         parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)

         is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
         if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1185,6 +1190,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
                 subfolder=subfolder or "",
                 dduf_entries=dduf_entries,
             )
+        elif use_flashpack:
+            try:
+                resolved_model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=FLASHPACK_WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
+                )
+            except IOError as e:
+                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                if not allow_pickle:
+                    raise
+                logger.warning(
+                    "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                )
+                # Fall through to the default weights loading below, mirroring the
+                # error-handling pattern of the safetensors branch.
         elif use_safetensors:
             try:
                 resolved_model_file = _get_model_file(
@@ -1248,6 +1277,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
         with ContextManagers(init_contexts):
             model = cls.from_config(config, **unused_kwargs)

+        if use_flashpack:
+            if not is_flashpack_available():
+                # Fail loudly instead of silently returning a model without weights.
+                raise ImportError(
+                    "`use_flashpack=True` requires the `flashpack` package, but it could not be imported."
+                )
+            import flashpack
+
+            flashpack.mixin.assign_from_file(
+                model=model,
+                path=resolved_model_file[0],
+                device=None if device_map is None else device_map[""],
+                # The remaining `assign_from_file` options (silent/strict flags,
+                # ignore_names/prefixes/suffixes, num_streams, chunk_bytes,
+                # distributed-loading rank/world_size, coerce_dtype, ...) are
+                # left at their library defaults for now.
+            )
+            return model
+
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
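`is_flashpack_available()` is likewise imported but not defined in this diff. diffusers gates optional backends through a shared `_is_package_available` helper in `utils/import_utils.py`; a self-contained sketch of that pattern (illustrative, not the PR's actual code):

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version


def is_flashpack_available() -> bool:
    """Return True only if `flashpack` is importable and has installed metadata."""
    if importlib.util.find_spec("flashpack") is None:
        return False
    try:
        version("flashpack")  # raises if the distribution is not installed
    except PackageNotFoundError:
        return False
    return True
```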
|
|