
Commit e0e8c58

[Core] separate the loading utilities in modeling similar to pipelines. (#7943)
1 parent cbea5d1 commit e0e8c58

File tree

2 files changed: +155 −115 lines

src/diffusers/models/model_loading_utils.py (new)
src/diffusers/models/modeling_utils.py
src/diffusers/models/model_loading_utils.py

Lines changed: 149 additions & 0 deletions

@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from collections import OrderedDict
from typing import List, Optional, Union

import safetensors
import torch

from ..utils import (
    SAFETENSORS_FILE_EXTENSION,
    is_accelerate_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)


if is_accelerate_available():
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
    if isinstance(device_map, str):
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)

    return device_map

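A minimal usage sketch of `_determine_device_map` (not part of the commit; the tiny UNet config is an assumption, and `accelerate` plus at least one visible GPU are assumed available): a string such as "balanced" is expanded into a per-module device assignment, while "sequential" only caps per-device memory before `infer_auto_device_map` runs.

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.model_loading_utils import _determine_device_map

# Test-sized UNet so the sketch runs without downloading any weights.
model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=1,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# "balanced" goes through get_balanced_memory(); "sequential" would only
# call get_max_memory(). Both branches end in infer_auto_device_map(),
# which returns a dict mapping module names to devices.
device_map = _determine_device_map(model, "balanced", None, torch.float16)
print(device_map)  # e.g. {"": 0} when everything fits on one GPU
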
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            return torch.load(
                checkpoint_file,
                map_location="cpu",
                **weights_only_kwarg,
            )
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(f"Unable to load weights from checkpoint file '{checkpoint_file}'.")

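A usage sketch for `load_state_dict` (not in the commit; the checkpoint paths are placeholders): the file extension alone selects the loader, and both branches load to CPU.

from diffusers.models.model_loading_utils import load_state_dict

# ".safetensors" -> safetensors.torch.load_file(..., device="cpu")
state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")

# Any other extension -> torch.load(..., map_location="cpu"),
# with weights_only=True on torch >= 1.13.
# state_dict = load_state_dict("unet/diffusion_pytorch_model.bin")

print(f"{len(state_dict)} tensors loaded")
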
def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
) -> List[str]:
    device = device or torch.device("cpu")
    dtype = dtype or torch.float32

    # Older versions of accelerate's `set_module_tensor_to_device` do not accept a `dtype` kwarg.
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

    unexpected_keys = []
    empty_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            unexpected_keys.append(param_name)
            continue

        if empty_state_dict[param_name].shape != param.shape:
            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
            raise ValueError(
                f"Cannot load {model_name_or_path_str}because {param_name} expected shape "
                f"{empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead "
                "overwrite randomly initialized weights, please make sure to pass both "
                "`low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, "
                "see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 "
                "as an example."
            )

        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
    return unexpected_keys

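A sketch of the intended low-CPU-memory flow around `load_model_dict_into_meta` (not in the commit; the repo id, checkpoint path, and CUDA device are placeholders): parameters are first created on the meta device, then materialized directly on the target device from the checkpoint.

import torch
from accelerate import init_empty_weights
from diffusers import UNet2DConditionModel
from diffusers.models.model_loading_utils import load_model_dict_into_meta, load_state_dict

config = UNet2DConditionModel.load_config("org/some-model", subfolder="unet")  # placeholder repo
with init_empty_weights():
    model = UNet2DConditionModel.from_config(config)  # meta tensors: no weight memory allocated yet

state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")  # placeholder path
unexpected = load_model_dict_into_meta(
    model,
    state_dict,
    device="cuda:0",
    dtype=torch.float16,
    model_name_or_path="org/some-model",
)
print(unexpected)  # checkpoint keys that have no matching parameter in the model
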
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix: str = ""):
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs
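
A self-contained sketch of `_load_state_dict_into_model`'s error handling (not in the commit): shape mismatches are collected into the returned list rather than raised, so callers can report them all at once.

import torch
from diffusers.models.model_loading_utils import _load_state_dict_into_model

model = torch.nn.Linear(4, 4)

# Matching shapes: tensors are copied in place and no errors come back.
print(_load_state_dict_into_model(model, {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}))  # []

# Mismatched shape: the error is recorded, not raised.
errors = _load_state_dict_into_model(model, {"weight": torch.zeros(2, 2), "bias": torch.zeros(4)})
print(errors)  # ['size mismatch for weight: copying a param with shape torch.Size([2, 2]) ...']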

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 115 deletions
@@ -33,7 +33,6 @@
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
-    SAFETENSORS_FILE_EXTENSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
@@ -44,6 +43,12 @@
     logging,
 )
 from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+from .model_loading_utils import (
+    _determine_device_map,
+    _load_state_dict_into_model,
+    load_model_dict_into_meta,
+    load_state_dict,
+)


 logger = logging.get_logger(__name__)
@@ -57,9 +62,6 @@

 if is_accelerate_available():
     import accelerate
-    from accelerate import infer_auto_device_map
-    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
-    from accelerate.utils.versions import is_torch_version


 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
@@ -100,117 +102,6 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
     return first_tuple[1].dtype


-# Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
-    ...
-
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
-    ...
-
-def load_model_dict_into_meta(
-    model,
-    state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
-    dtype: Optional[Union[str, torch.dtype]] = None,
-    model_name_or_path: Optional[str] = None,
-) -> List[str]:
-    ...
-
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
-    ...

(The 111 deleted lines are the four loading helpers, removed here verbatim after the move to model_loading_utils.py shown above; the only change made in the move is that `_determine_device_map` now annotates `model` as `torch.nn.Module` instead of `"ModelMixin"`.)

 class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
