Skip to content

Commit 8385f45

Browse files
committed
update
1 parent 7a9c448 commit 8385f45

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,14 @@ def load_model_dict_into_meta(
233233
empty_state_dict = model.state_dict()
234234
expanded_device_map = {}
235235

236-
if device_map is not None:
237-
for param_name, param in state_dict.items():
238-
if param_name not in empty_state_dict:
239-
continue
240-
param_device = _determine_param_device(param_name, device_map)
241-
expanded_device_map[param_name] = param_device
242-
print(expanded_device_map)
243-
_caching_allocator_warmup(model, expanded_device_map, dtype)
236+
# if device_map is not None:
237+
# for param_name, param in state_dict.items():
238+
# if param_name not in empty_state_dict:
239+
# continue
240+
# param_device = _determine_param_device(param_name, device_map)
241+
# expanded_device_map[param_name] = param_device
242+
# print(expanded_device_map)
243+
# _caching_allocator_warmup(model, expanded_device_map, dtype)
244244

245245
for param_name, param in state_dict.items():
246246
if param_name not in empty_state_dict:

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,6 +1557,8 @@ def _find_mismatched_keys(
15571557

15581558
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
15591559

1560+
torch.cuda.synchronize()
1561+
15601562
if offload_index is not None and len(offload_index) > 0:
15611563
save_offload_index(offload_index, offload_folder)
15621564
offload_index = None

0 commit comments

Comments
 (0)