```diff
@@ -14,12 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import importlib
 import inspect
 import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -31,6 +33,7 @@

 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -310,6 +313,161 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index


+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks whether `model_to_load` supports param buffer assignment (such as when loading in empty weights), first by
+    checking whether the model explicitly disables it, then by verifying that the incoming `state_dict` matches the
+    model's dtype.
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
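+    # Nothing to assign if no checkpoint key carries the expected prefix.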
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must share the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+def _load_shard_file(
+    shard_file,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
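+    # Two loading paths: `load_model_dict_into_meta` materializes tensors directly onto their
+    # target devices (optionally offloading to disk), while the regular path copies into an
+    # already-allocated model, assigning tensors in place when the model supports it.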
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+    shard_files,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    # Do not spawn any more workers than needed
+    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
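+    # Threads (rather than processes) suffice here: shard loading is dominated by file I/O
+    # and tensor copies that release the GIL, so workers can overlap with each other.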
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
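+            # Remove shape-mismatched tensors from the state dict so they are skipped during
+            # loading and reported back to the caller instead of raising a size error.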
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
```
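For context, here is a rough sketch of how the new helpers might be driven by a caller. It is illustrative only: the shard file names, the `torch.float16` dtype, and the direct import of a private helper are assumptions, and the real call site (`ModelMixin.from_pretrained`) resolves the shards from the index file and passes many more options.

```python
# Illustrative sketch only -- not the actual diffusers call site.
import torch

# Private helper from this diff; module path assumed from the imports shown above.
from diffusers.models.model_loading_utils import _load_shard_files_with_threadpool

# Hypothetical shard list, normally taken from the safetensors index file.
shard_files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]

model_state_dict = model.state_dict()  # assumes `model` was instantiated earlier (e.g. on the meta device)
offload_index, state_dict_index, mismatched_keys, error_msgs = _load_shard_files_with_threadpool(
    shard_files,
    model,
    model_state_dict,
    dtype=torch.float16,
    loaded_keys=list(model_state_dict.keys()),  # assumes checkpoint keys match the model's
    unexpected_keys=[],
    low_cpu_mem_usage=True,
)
if error_msgs:
    raise RuntimeError("Error(s) loading state_dict:\n\t" + "\n\t".join(error_msgs))
```

Because `as_completed` yields futures in completion order, shard results arrive out of order; that is safe here since each shard covers a disjoint set of parameters and the accumulated `mismatched_keys` and `error_msgs` lists are order-independent.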