Skip to content

Commit ea446b1

Browse files
committed
add comment explanations
1 parent e364dfd commit ea446b1

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
431431
keep_in_fp32_modules=keep_in_fp32_modules,
432432
unexpected_keys=unexpected_keys,
433433
)
434+
# Ensure tensors are correctly placed on device by synchronizing before returning control to the user. This is
435+
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
434436
empty_device_cache()
435437
device_synchronize()
436438
else:

src/diffusers/loaders/single_file_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,8 @@ def create_diffusers_clip_model_from_ldm(
16901690

16911691
if is_accelerate_available():
16921692
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1693+
# Ensure tensors are correctly placed on device by synchronizing before returning control to the user. This is
1694+
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
16931695
empty_device_cache()
16941696
device_synchronize()
16951697
else:
@@ -2151,6 +2153,8 @@ def create_diffusers_t5_model_from_checkpoint(
21512153

21522154
if is_accelerate_available():
21532155
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2156+
# Ensure tensors are correctly placed on device by synchronizing before returning control to the user. This is
2157+
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
21542158
empty_device_cache()
21552159
device_synchronize()
21562160
else:

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,8 +1486,14 @@ def _load_pretrained_model(
14861486
if offload_state_dict is None:
14871487
offload_state_dict = True
14881488

1489-
# Caching allocator warmup
1490-
if device_map is not None:
1489+
# If a device map has been used, we can speedup the load time by warming up the device caching allocator.
1490+
# If we don't warm up, each tensor allocation on the device makes a separate call to the allocator for memory (effectively, a
1491+
# lot of individual calls to device malloc). We can, however, preallocate the memory required by the
1492+
# tensors using their expected shape and not performing any initialization of the memory (empty data).
1493+
# When the actual device allocations happen, the allocator already has a pool of unused device memory
1494+
# that it can re-use for faster loading of the model.
1495+
# TODO: add support for warmup with hf_quantizer
1496+
if device_map is not None and hf_quantizer is None:
14911497
expanded_device_map = _expand_device_map(device_map, expected_keys)
14921498
_caching_allocator_warmup(model, expanded_device_map, dtype)
14931499

@@ -1534,6 +1540,8 @@ def _load_pretrained_model(
15341540
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
15351541
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
15361542

1543+
# Ensure tensors are correctly placed on device by synchronizing before returning control to the user. This is
1544+
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
15371545
empty_device_cache()
15381546
device_synchronize()
15391547

0 commit comments

Comments
 (0)