22 | 22 | from typing import Any, Dict, List, Optional, Union
23 | 23 |
24 | 24 | import torch
25 |    | -from huggingface_hub import (
26 |    | -    model_info,
27 |    | -)
   | 25 | +from huggingface_hub import model_info
   | 26 | +from huggingface_hub.utils import validate_hf_hub_args
28 | 27 | from packaging import version
29 | 28 |
   | 29 | +from .. import __version__
30 | 30 | from ..utils import (
   | 31 | +    FLAX_WEIGHTS_NAME,
   | 32 | +    ONNX_EXTERNAL_WEIGHTS_NAME,
   | 33 | +    ONNX_WEIGHTS_NAME,
31 | 34 |     SAFETENSORS_WEIGHTS_NAME,
32 | 35 |     WEIGHTS_NAME,
33 | 36 |     get_class_from_dynamic_module,
   | 37 | +    is_accelerate_available,
34 | 38 |     is_peft_available,
35 | 39 |     is_transformers_available,
36 | 40 |     logging,
44 | 48 | from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
45 | 49 | from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
46 | 50 | from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
47 |    | -from huggingface_hub.utils import validate_hf_hub_args
48 | 51 |
49 |    | -from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
   | 52 | +if is_accelerate_available():
   | 53 | +    import accelerate
   | 54 | +    from accelerate import dispatch_model
   | 55 | +    from accelerate.hooks import remove_hook_from_module
   | 56 | +    from accelerate.utils import compute_module_sizes, get_max_memory
50 | 57 |
51 | 58 |
52 | 59 | INDEX_FILE = "diffusion_pytorch_model.bin"
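These guarded imports pull in the `accelerate` helpers the new device-mapping code builds on. As a quick orientation, here is a minimal sketch (not part of the diff; `TinyNet` is an invented toy module) of what the two sizing utilities return:

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import compute_module_sizes, get_max_memory


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(1024, 1024)


# Under init_empty_weights(), parameters are created on the "meta"
# device, so the model has shapes but occupies no real memory.
with init_empty_weights():
    net = TinyNet()

# compute_module_sizes() maps module names to sizes in bytes;
# the "" key holds the total for the whole model.
sizes = compute_module_sizes(net, dtype=torch.float16)
print(sizes[""])  # ~2.1 MB for one 1024x1024 fp16 linear layer

# get_max_memory(None) reports the available memory per device
# (GPU indices plus "cpu").
print(get_max_memory(None))
```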
@@ -376,6 +383,207 @@ def _get_pipeline_class(
376 | 383 |     return pipeline_cls
377 | 384 |
378 | 385 |
   | 386 | +def _load_empty_model(
   | 387 | +    library_name: str,
   | 388 | +    class_name: str,
   | 389 | +    importable_classes: List[Any],
   | 390 | +    pipelines: Any,
   | 391 | +    is_pipeline_module: bool,
   | 392 | +    name: str,
   | 393 | +    torch_dtype: Union[str, torch.dtype],
   | 394 | +    cached_folder: Union[str, os.PathLike],
   | 395 | +    **kwargs,
   | 396 | +):
   | 397 | +    # Retrieve the class object.
   | 398 | +    class_obj, _ = get_class_obj_and_candidates(
   | 399 | +        library_name,
   | 400 | +        class_name,
   | 401 | +        importable_classes,
   | 402 | +        pipelines,
   | 403 | +        is_pipeline_module,
   | 404 | +        component_name=name,
   | 405 | +        cache_dir=cached_folder,
   | 406 | +    )
   | 407 | +
   | 408 | +    if is_transformers_available():
   | 409 | +        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
   | 410 | +    else:
   | 411 | +        transformers_version = "N/A"
   | 412 | +
   | 413 | +    # Determine which library the component comes from.
   | 414 | +    is_transformers_model = (
   | 415 | +        is_transformers_available()
   | 416 | +        and issubclass(class_obj, PreTrainedModel)
   | 417 | +        and transformers_version >= version.parse("4.20.0")
   | 418 | +    )
   | 419 | +    diffusers_module = importlib.import_module(__name__.split(".")[0])
   | 420 | +    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
   | 421 | +
   | 422 | +    model = None
   | 423 | +    config_path = cached_folder
   | 424 | +    user_agent = {
   | 425 | +        "diffusers": __version__,
   | 426 | +        "file_type": "model",
   | 427 | +        "framework": "pytorch",
   | 428 | +    }
   | 429 | +
   | 430 | +    if is_diffusers_model:
   | 431 | +        # Load the config, then instantiate the model on the meta device.
   | 432 | +        config, unused_kwargs, commit_hash = class_obj.load_config(
   | 433 | +            os.path.join(config_path, name),
   | 434 | +            cache_dir=cached_folder,
   | 435 | +            return_unused_kwargs=True,
   | 436 | +            return_commit_hash=True,
   | 437 | +            force_download=kwargs.pop("force_download", False),
   | 438 | +            resume_download=kwargs.pop("resume_download", False),
   | 439 | +            proxies=kwargs.pop("proxies", None),
   | 440 | +            local_files_only=kwargs.pop("local_files_only", False),
   | 441 | +            token=kwargs.pop("token", None),
   | 442 | +            revision=kwargs.pop("revision", None),
   | 443 | +            subfolder=kwargs.pop("subfolder", None),
   | 444 | +            user_agent=user_agent,
   | 445 | +        )
   | 446 | +        with accelerate.init_empty_weights():
   | 447 | +            model = class_obj.from_config(config, **unused_kwargs)
   | 448 | +    elif is_transformers_model:
   | 449 | +        config_class = getattr(class_obj, "config_class", None)
   | 450 | +        if config_class is None:
   | 451 | +            raise ValueError("`config_class` cannot be None. Please double-check the model.")
   | 452 | +
   | 453 | +        config = config_class.from_pretrained(
   | 454 | +            cached_folder,
   | 455 | +            subfolder=name,
   | 456 | +            force_download=kwargs.pop("force_download", False),
   | 457 | +            resume_download=kwargs.pop("resume_download", False),
   | 458 | +            proxies=kwargs.pop("proxies", None),
   | 459 | +            local_files_only=kwargs.pop("local_files_only", False),
   | 460 | +            token=kwargs.pop("token", None),
   | 461 | +            revision=kwargs.pop("revision", None),
   | 462 | +            user_agent=user_agent,
   | 463 | +        )
   | 464 | +        with accelerate.init_empty_weights():
   | 465 | +            model = class_obj(config)
   | 466 | +
   | 467 | +    if model is not None:
   | 468 | +        model = model.to(dtype=torch_dtype)
   | 469 | +    return model
   | 470 | +
   | 471 | +
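`_load_empty_model` never materializes weights: it fetches only the config and instantiates the component inside `init_empty_weights()`. A minimal sketch of the `transformers` branch (the repo id and subfolder are illustrative only):

```python
from accelerate import init_empty_weights
from transformers import CLIPTextModel

# Fetch only the config of the text encoder; no weight files are downloaded.
config = CLIPTextModel.config_class.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)

# Build the model on the meta device: correct shapes for sizing,
# zero bytes of real memory.
with init_empty_weights():
    text_encoder = CLIPTextModel(config)

assert text_encoder.device.type == "meta"
```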
   | 472 | +def _assign_components_to_devices(
   | 473 | +    module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
   | 474 | +):
   | 475 | +    device_ids = list(device_memory.keys())
   | 476 | +    device_cycle = device_ids + device_ids[::-1]
   | 477 | +    device_memory = device_memory.copy()
   | 478 | +
   | 479 | +    device_id_component_mapping = {}
   | 480 | +    current_device_index = 0
   | 481 | +    for component in module_sizes:
   | 482 | +        device_id = device_cycle[current_device_index % len(device_cycle)]
   | 483 | +        component_memory = module_sizes[component]
   | 484 | +        curr_device_memory = device_memory[device_id]
   | 485 | +
   | 486 | +        # If the current component doesn't fit on the chosen device, offload it to the CPU.
   | 487 | +        if component_memory > curr_device_memory:
   | 488 | +            device_id_component_mapping.setdefault("cpu", []).append(component)
   | 489 | +        else:
   | 490 | +            if device_id not in device_id_component_mapping:
   | 491 | +                device_id_component_mapping[device_id] = [component]
   | 492 | +            else:
   | 493 | +                device_id_component_mapping[device_id].append(component)
   | 494 | +
   | 495 | +            # Update the remaining memory on the device.
   | 496 | +            device_memory[device_id] -= component_memory
   | 497 | +            current_device_index += 1
   | 498 | +
   | 499 | +    return device_id_component_mapping
   | 500 | +
   | 501 | +
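To make the back-and-forth ("balanced") device cycle concrete, here is a small worked example with invented sizes, tracing the function above:

```python
# Sizes in bytes, sorted largest-first as _get_final_device_map does.
module_sizes = {"unet": 10_000, "text_encoder": 4_000, "vae": 2_000}
device_memory = {0: 12_000, 1: 12_000}

# The cycle over two GPUs is [0, 1, 1, 0]:
#   unet         -> index 0 -> GPU 0 (memory left: 2_000)
#   text_encoder -> index 1 -> GPU 1 (memory left: 8_000)
#   vae          -> index 2 -> GPU 1 (memory left: 6_000)
mapping = _assign_components_to_devices(module_sizes, device_memory)
assert mapping == {0: ["unet"], 1: ["text_encoder", "vae"]}
```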
   | 502 | +def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
   | 503 | +    # Imported here to avoid a circular import.
   | 504 | +    from diffusers import pipelines
   | 505 | +
   | 506 | +    torch_dtype = kwargs.get("torch_dtype", torch.float32)
   | 507 | +
   | 508 | +    # Load each module in the pipeline on the meta device so that we can derive the device map.
   | 509 | +    init_empty_modules = {}
   | 510 | +    for name, (library_name, class_name) in init_dict.items():
   | 511 | +        if class_name.startswith("Flax"):
   | 512 | +            raise ValueError("Flax pipelines are not supported with `device_map`.")
   | 513 | +
   | 514 | +        # Define all importable classes.
   | 515 | +        is_pipeline_module = hasattr(pipelines, library_name)
   | 516 | +        importable_classes = ALL_IMPORTABLE_CLASSES
   | 517 | +        loaded_sub_model = None
   | 518 | +
   | 519 | +        # Use the passed sub-model or load `class_name` from `library_name`.
   | 520 | +        if name in passed_class_obj:
   | 521 | +            # If the model is in a pipeline module, then we load it from the pipeline;
   | 522 | +            # check that `passed_class_obj` has the correct parent class.
   | 523 | +            maybe_raise_or_warn(
   | 524 | +                library_name,
   | 525 | +                library,
   | 526 | +                class_name,
   | 527 | +                importable_classes,
   | 528 | +                passed_class_obj,
   | 529 | +                name,
   | 530 | +                is_pipeline_module,
   | 531 | +            )
   | 532 | +            with accelerate.init_empty_weights():
   | 533 | +                loaded_sub_model = passed_class_obj[name]
   | 534 | +
   | 535 | +        else:
   | 536 | +            loaded_sub_model = _load_empty_model(
   | 537 | +                library_name=library_name,
   | 538 | +                class_name=class_name,
   | 539 | +                importable_classes=importable_classes,
   | 540 | +                pipelines=pipelines,
   | 541 | +                is_pipeline_module=is_pipeline_module,
   | 542 | +                pipeline_class=pipeline_class,
   | 543 | +                name=name,
   | 544 | +                torch_dtype=torch_dtype,
   | 545 | +                cached_folder=kwargs.get("cached_folder", None),
   | 546 | +                force_download=kwargs.get("force_download", None),
   | 547 | +                resume_download=kwargs.get("resume_download", None),
   | 548 | +                proxies=kwargs.get("proxies", None),
   | 549 | +                local_files_only=kwargs.get("local_files_only", None),
   | 550 | +                token=kwargs.get("token", None),
   | 551 | +                revision=kwargs.get("revision", None),
   | 552 | +            )
   | 553 | +
   | 554 | +        if loaded_sub_model is not None:
   | 555 | +            init_empty_modules[name] = loaded_sub_model
   | 556 | +
   | 557 | +    # Determine the device map:
   | 558 | +    # obtain a sorted dictionary mapping the model-level components
   | 559 | +    # to their sizes.
   | 560 | +    module_sizes = {
   | 561 | +        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
   | 562 | +        for module_name, module in init_empty_modules.items()
   | 563 | +        if isinstance(module, torch.nn.Module)
   | 564 | +    }
   | 565 | +    module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
   | 566 | +
   | 567 | +    # Obtain the maximum memory available per device (GPUs only).
   | 568 | +    max_memory = get_max_memory(max_memory)
   | 569 | +    max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
   | 570 | +    max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
   | 571 | +
   | 572 | +    # Obtain a dictionary mapping the model-level components to the available
   | 573 | +    # devices based on the maximum memory and the component sizes.
   | 574 | +    device_id_component_mapping = _assign_components_to_devices(
   | 575 | +        module_sizes, max_memory, device_mapping_strategy=device_map
   | 576 | +    )
   | 577 | +
   | 578 | +    # Invert into the final device map, e.g. `{"unet": 0, "text_encoder": 1, "vae": 1}`.
   | 579 | +    final_device_map = {}
   | 580 | +    for device_id, components in device_id_component_mapping.items():
   | 581 | +        for component in components:
   | 582 | +            final_device_map[component] = device_id
   | 583 | +
   | 584 | +    return final_device_map
   | 585 | +
   | 586 | +
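Wired into `DiffusionPipeline.from_pretrained`, the net effect is component-level sharding across the visible GPUs. A usage sketch, assuming the `balanced` strategy this diff introduces (the repo id is an example):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    device_map="balanced",  # the strategy handled by _get_final_device_map
)

# Component-level placement, e.g. {"unet": 0, "text_encoder": 1, "vae": 1}.
print(pipe.hf_device_map)
```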
379 | 587 | def load_sub_model(
380 | 588 |     library_name: str,
381 | 589 |     class_name: str,
@@ -493,6 +701,22 @@ def load_sub_model(
493 | 701 |         # else load from the root directory
494 | 702 |         loaded_sub_model = load_method(cached_folder, **loading_kwargs)
495 | 703 |
    | 704 | +    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
    | 705 | +        # Remove any hooks left from a previous dispatch before re-dispatching.
    | 706 | +        remove_hook_from_module(loaded_sub_model, recurse=True)
    | 707 | +        needs_offloading_to_cpu = device_map[""] == "cpu"
    | 708 | +
    | 709 | +        if needs_offloading_to_cpu:
    | 710 | +            dispatch_model(
    | 711 | +                loaded_sub_model,
    | 712 | +                state_dict=loaded_sub_model.state_dict(),
    | 713 | +                device_map=device_map,
    | 714 | +                force_hooks=True,
    | 715 | +                main_device=0,
    | 716 | +            )
    | 717 | +        else:
    | 718 | +            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
    | 719 | +
496 | 720 |     return loaded_sub_model
497 | 721 |
498 | 722 |
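For reference, `dispatch_model` attaches accelerate's alignment hooks so inputs and weights meet on the mapped device; `force_hooks=True` forces the hooks even when the map names a single device, which is why the hunk above first strips stale hooks. A minimal standalone sketch (requires a CUDA device):

```python
import torch
from accelerate import dispatch_model
from accelerate.hooks import remove_hook_from_module

model = torch.nn.Linear(8, 8)

# Place the whole module on GPU 0; force_hooks=True attaches the
# dispatch hooks even though the map names a single device.
model = dispatch_model(model, device_map={"": 0}, force_hooks=True)

# Hooks can later be stripped again, as load_sub_model does before
# re-dispatching a freshly loaded component.
remove_hook_from_module(model, recurse=True)
```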