@@ -1486,8 +1486,14 @@ def _load_pretrained_model(
             if offload_state_dict is None:
                 offload_state_dict = True
 
-        # Caching allocator warmup
-        if device_map is not None:
+        # If a device map has been used, we can speed up the load time by warming up the device caching allocator.
+        # Without the warmup, each tensor allocation on device goes through the allocator for memory (effectively,
+        # a lot of individual calls to device malloc). Instead, we can preallocate the memory required by the
+        # tensors from their expected shapes, without performing any initialization of the memory (empty data).
+        # When the actual device allocations happen, the allocator already has a pool of unused device memory
+        # that it can re-use for faster loading of the model.
+        # TODO: add support for warmup with hf_quantizer
+        if device_map is not None and hf_quantizer is None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
             _caching_allocator_warmup(model, expanded_device_map, dtype)
 
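The warmup added above can be illustrated with a small standalone sketch. Everything below is hypothetical and is not the library's `_caching_allocator_warmup`: it sums, per device, the bytes that the parameters in a device map will eventually need, allocates one uninitialized buffer of that size, and immediately drops the reference so the caching allocator keeps the memory in its pool for the real per-tensor allocations.

```python
from collections import defaultdict

import torch


def warmup_caching_allocator(param_shapes, device_map, dtype=torch.float16):
    """Hypothetical sketch of a device caching-allocator warmup.

    param_shapes: parameter name -> shape (tuple of ints)
    device_map:   parameter name -> device string (e.g. "cuda:0", "cpu", "disk")
    """
    bytes_per_elem = torch.empty(0, dtype=dtype).element_size()

    # Sum the memory each accelerator device will eventually need.
    bytes_per_device = defaultdict(int)
    for name, shape in param_shapes.items():
        device = device_map.get(name)
        if device is None or device in ("cpu", "disk"):
            continue  # only accelerator devices benefit from the warmup
        numel = 1
        for dim in shape:
            numel *= dim
        bytes_per_device[device] += numel * bytes_per_elem

    # Allocate uninitialized memory once per device, then drop the reference.
    # The caching allocator keeps the freed block in its pool, so the later
    # per-tensor allocations are served without fresh device mallocs.
    for device, total_bytes in bytes_per_device.items():
        _ = torch.empty(total_bytes // bytes_per_elem, dtype=dtype, device=device)
```

With a device map such as `{"transformer.wte.weight": "cuda:0", ...}`, this reserves the pool in one large allocation per device instead of one small allocation per tensor.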
@@ -1534,6 +1540,8 @@ def _load_pretrained_model(
                     assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
                 error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
 
+        # Ensure tensors are correctly placed on device by synchronizing before returning control to the user. This
+        # is required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
         device_synchronize()
 
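The synchronization added before returning can likewise be shown in isolation. The sketch below is only an illustration under the assumption of a CUDA device and a plain state dict: `non_blocking=True` queues the host-to-device copies asynchronously, so a single synchronize at the end guarantees the data has actually landed on the device before the caller uses it, mirroring the `device_synchronize()` call in the hunk above.

```python
import torch


def move_state_dict_to_device(state_dict, device="cuda"):
    """Minimal sketch: asynchronous copies followed by one synchronize."""
    moved = {}
    for name, tensor in state_dict.items():
        # non_blocking=True queues the host-to-device copy instead of waiting
        # for it to finish; it is effective when the source is in pinned memory.
        moved[name] = tensor.pin_memory().to(device, non_blocking=True)

    # The copies may still be in flight here: wait for the device before
    # returning control to the caller.
    torch.cuda.synchronize(device)
    return moved
```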