@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
     very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    total_byte_count = defaultdict(lambda: 0)
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
     for param_name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
+            p = model.get_parameter(param_name)
         except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
747+                 raise  AttributeError (f"Parameter or buffer with name={ param_name }  )
         # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+        elements_per_device[device] += p.numel()

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
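
A minimal sketch of what the final torch.empty call buys, assuming a CUDA device is available; the tensor size below is arbitrary and not the value the function computes, it only illustrates the caching-allocator behavior the comment refers to:

import torch

device = torch.device("cuda:0")
print(torch.cuda.memory_reserved(device))   # typically 0 before any allocation

# One large throwaway allocation forces the caching allocator to reserve the memory up front.
_ = torch.empty(64 * 1024 * 1024, dtype=torch.float16, device=device, requires_grad=False)
del _  # the tensor itself is freed...

print(torch.cuda.memory_allocated(device))  # ...so allocated bytes drop back toward 0,
print(torch.cuda.memory_reserved(device))   # ...but the allocator keeps the block cached, so later
                                            # parameter loads reuse it instead of hitting cudaMalloc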