 
 import importlib
 import inspect
+import math
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -230,6 +231,16 @@ def load_model_dict_into_meta(
 
     is_quantized = hf_quantizer is not None
     empty_state_dict = model.state_dict()
+    expanded_device_map = {}
+
+    if device_map is not None:
+        for param_name, param in state_dict.items():
+            if param_name not in empty_state_dict:
+                continue
+            param_device = _determine_param_device(param_name, device_map)
+            expanded_device_map[param_name] = param_device
+        # Pre-reserve memory on the target accelerator devices so the per-parameter copies below reuse the caching allocator
+        _caching_allocator_warmup(model, expanded_device_map, dtype)
 
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
@@ -243,13 +254,13 @@ def load_model_dict_into_meta(
         if keep_in_fp32_modules is not None and any(
             module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
         ):
-            param = param.to(torch.float32)
+            param = param.to(torch.float32, non_blocking=True)
             set_module_kwargs["dtype"] = torch.float32
         # For quantizers have save weights using torch.float8_e4m3fn
         elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
             pass
         else:
-            param = param.to(dtype)
+            param = param.to(dtype, non_blocking=True)
             set_module_kwargs["dtype"] = dtype
 
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
@@ -265,7 +276,7 @@ def load_model_dict_into_meta(
 
         if old_param is not None:
             if dtype is None:
-                param = param.to(old_param.dtype)
+                param = param.to(old_param.dtype, non_blocking=True)
 
             if old_param.is_contiguous():
                 param = param.contiguous()
@@ -520,3 +531,29 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
 
     return parsed_parameters
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+    """This function warms up the caching allocator based on the size of the model tensors that will reside on each
+    device. It allows a single large call to malloc instead of many smaller calls later while loading the model,
+    which is the actual loading-speed bottleneck. Calling this function can cut model loading time by a very large
+    margin.
+    """
+    # Remove disk and cpu devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    parameter_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            param = model.get_parameter(param_name)
+        except AttributeError:
+            param = model.get_buffer(param_name)
+        parameter_count[device] += math.prod(param.shape)
+
+    # This will kick off the caching allocator to avoid having to malloc afterwards
+    for device, param_count in parameter_count.items():
+        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
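
For readers unfamiliar with the expanded_device_map built in the first hunk: the diff relies on a _determine_param_device helper that is not shown here. The sketch below is an illustrative stand-in, not the repository's implementation; it assumes the usual accelerate-style convention where device_map keys are module-name prefixes and "" is the catch-all entry.

from typing import Dict, Iterable

import torch


def expand_device_map(device_map: Dict[str, str], param_names: Iterable[str]) -> Dict[str, torch.device]:
    # Hypothetical helper for illustration only: assign each parameter the device
    # of its longest matching module prefix, falling back to the "" catch-all key.
    expanded = {}
    for name in param_names:
        matches = [k for k in device_map if k == "" or name == k or name.startswith(k + ".")]
        expanded[name] = torch.device(device_map[max(matches, key=len)])
    return expanded


print(expand_device_map(
    {"": "cuda:0", "text_encoder": "cpu"},
    ["unet.conv_in.weight", "text_encoder.embed.weight"],
))
# {'unet.conv_in.weight': device(type='cuda', index=0), 'text_encoder.embed.weight': device(type='cpu')}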
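
And a standalone sketch of the allocator-warmup trick itself, assuming a CUDA device is available: one large torch.empty() per target device reserves memory in PyTorch's caching allocator, so the many small per-parameter allocations that follow reuse cached blocks instead of each hitting cudaMalloc. The shapes and the timing harness are illustrative, not taken from this commit.

import math
import time
from collections import defaultdict

import torch


def warmup_allocator(param_shapes, dtype, device):
    # Sum element counts per target device, then reserve that many elements in a
    # single allocation. The tensor is dropped right away, but the memory stays
    # cached inside torch's caching allocator for later allocations to reuse.
    counts = defaultdict(int)
    for shape in param_shapes:
        counts[torch.device(device)] += math.prod(shape)
    for dev, numel in counts.items():
        _ = torch.empty(numel, dtype=dtype, device=dev, requires_grad=False)


if torch.cuda.is_available():
    shapes = [(4096, 4096)] * 64  # stand-in for a model's parameter shapes
    warmup_allocator(shapes, torch.float16, "cuda:0")

    # Later per-parameter allocations now come out of the warmed cache.
    start = time.perf_counter()
    tensors = [torch.empty(s, dtype=torch.float16, device="cuda:0") for s in shapes]
    torch.cuda.synchronize()
    print(f"allocated {len(tensors)} tensors in {time.perf_counter() - start:.4f}s")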