Skip to content

Commit 37b6c1f

Browse files
authored
Update modeling_utils.py
1 parent 0d9d98f commit 37b6c1f

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import inspect
1919
import itertools
2020
import json
21+
import time
2122
import os
2223
import re
2324
from collections import OrderedDict
@@ -538,6 +539,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
538539
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
539540
```
540541
"""
542+
init = time.time()
541543
cache_dir = kwargs.pop("cache_dir", None)
542544
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
543545
force_download = kwargs.pop("force_download", False)
@@ -637,6 +639,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
637639
}
638640

639641
# load config
642+
start = time.time()
640643
config, unused_kwargs, commit_hash = cls.load_config(
641644
config_path,
642645
cache_dir=cache_dir,
@@ -651,6 +654,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
651654
user_agent=user_agent,
652655
**kwargs,
653656
)
657+
print(time.time() - start, "config load")
654658
# no in-place modification of the original config.
655659
config = copy.deepcopy(config)
656660

@@ -816,6 +820,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
816820
)
817821

818822
if low_cpu_mem_usage:
823+
print("low cpu mem use")
819824
# Instantiate model with empty weights
820825
with accelerate.init_empty_weights():
821826
model = cls.from_config(config, **unused_kwargs)
@@ -829,6 +834,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
829834
if device_map is None and not is_sharded:
830835
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
831836
# It would error out during the `validate_environment()` call above in the absence of cuda.
837+
832838
is_quant_method_bnb = (
833839
getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
834840
)
@@ -874,6 +880,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
874880
else: # else let accelerate handle loading and dispatching.
875881
# Load weights and dispatch according to the device_map
876882
# by default the device_map is None and the weights are loaded on the CPU
883+
print("accelerate load checkpoint")
877884
force_hook = True
878885
device_map = _determine_device_map(
879886
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
@@ -935,11 +942,23 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
935942
"error_msgs": [],
936943
}
937944
else:
945+
print("did from_config")
946+
print(cls)
947+
torch.cuda.synchronize("cuda")
948+
start = time.time()
938949
model = cls.from_config(config, **unused_kwargs)
950+
torch.cuda.synchronize("cuda")
951+
print(time.time()-start,"from_config")
939952

953+
torch.cuda.synchronize("cuda")
954+
start = time.time()
940955
state_dict = load_state_dict(model_file, variant=variant)
956+
torch.cuda.synchronize("cuda")
957+
print(time.time()-start,"load_state_dict")
941958
model._convert_deprecated_attention_blocks(state_dict)
942959

960+
torch.cuda.synchronize("cuda")
961+
start = time.time()
943962
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
944963
model,
945964
state_dict,
@@ -948,13 +967,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
948967
ignore_mismatched_sizes=ignore_mismatched_sizes,
949968
)
950969

970+
torch.cuda.synchronize("cuda")
971+
print(time.time() - start, "load_pretrained_model")
972+
951973
loading_info = {
952974
"missing_keys": missing_keys,
953975
"unexpected_keys": unexpected_keys,
954976
"mismatched_keys": mismatched_keys,
955977
"error_msgs": error_msgs,
956978
}
957979

980+
torch.cuda.synchronize("cuda")
981+
start = time.time()
958982
if hf_quantizer is not None:
959983
hf_quantizer.postprocess_model(model)
960984
model.hf_quantizer = hf_quantizer
@@ -975,11 +999,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
975999
else:
9761000
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
9771001

1002+
torch.cuda.synchronize("cuda")
1003+
print(time.time()-start,"to device")
1004+
9781005
# Set model in evaluation mode to deactivate DropOut modules by default
9791006
model.eval()
9801007
if output_loading_info:
9811008
return model, loading_info
9821009

1010+
print(time.time()-init,"total")
9831011
return model
9841012

9851013
# Adapted from `transformers`.

0 commit comments

Comments
 (0)