     is_fp8_model,
     is_hpex_available,
     llm_load_model,
+    memory_monitor,
     mv_module_from_gpu,
     normalize_input,
     set_amax_for_all_moe_layers,
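
Note: the hunks below call memory_monitor.update(), memory_monitor.log_summary(), and memory_monitor.get_summary(). The real object is imported from the library's own utils; what follows is only a minimal sketch of the interface these calls appear to assume (names, units, and behavior are illustrative, not the project's implementation):

    # Hypothetical sketch of the memory_monitor interface assumed by this diff;
    # the real object is imported from the library's utils and may differ.
    import torch

    class _MemoryMonitor:
        """Tracks peak accelerator memory across quantization steps."""

        def __init__(self) -> None:
            self.peak_gpu_gb = 0.0

        def update(self) -> None:
            # Record the current CUDA peak; a fuller version would also
            # sample host RSS (e.g., via psutil) and other accelerators.
            if torch.cuda.is_available():
                peak = torch.cuda.max_memory_allocated() / 1024**3
                self.peak_gpu_gb = max(self.peak_gpu_gb, peak)

        def get_summary(self) -> str:
            self.update()
            return f"peak gpu memory: {self.peak_gpu_gb:.2f} GiB"

        def log_summary(self) -> None:
            print(self.get_summary())

    memory_monitor = _MemoryMonitor()
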
@@ -1025,6 +1026,7 @@ def quantize_and_save(
             self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs)

             folders.append(save_folder)
+        memory_monitor.log_summary()

         return model, folders

@@ -1513,6 +1515,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                     all_to_quantized_module_names.remove(m.tmp_name)
             if not self.immediate_saving:
                 mv_module_from_gpu(block)
+            memory_monitor.log_summary()
             pbar.update(1)

         pbar.close()
@@ -1752,6 +1755,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 layer.cpu()
                 layer_names.remove(layer_name)
         if len(layer_names) == 0:
+            memory_monitor.update()
+            memory_monitor.log_summary()
             return
         q_layer_inputs = None
         enable_quanted_input = self.enable_quanted_input
@@ -1770,7 +1775,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
             accelerate.hooks.remove_hook_from_submodules(
                 self.model
-            )  ## self.model.hf_device_map has not been changed
+            )  # self.model.hf_device_map has not been changed
         if not self.immediate_saving:
             self.model = mv_module_from_gpu(self.model)
         clear_memory(device_list=self.device_list)
@@ -1789,13 +1794,14 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 immediate_saving(self, m, name=layer_name, last_group=True)
         del layer_input
         clear_memory(q_layer_input, device_list=self.device_list)
+        memory_monitor.log_summary()

     @torch.no_grad()
     def _get_block_outputs(
         self,
         block: torch.nn.Module,
-        input_ids: torch.Tensor,
-        input_others: torch.Tensor,
+        input_ids: torch.Tensor | list[torch.Tensor],
+        input_others: torch.Tensor | dict,
         bs: int,
         device: Union[str, torch.device],
         cache_device: Union[str, torch.device],
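
The widened annotations above reflect that block inputs may arrive either as one stacked tensor or as a list of per-sample tensors (and input_others as a dict of shared kwargs). A hedged sketch of the kind of batch slicing this implies; the helper name is illustrative, not the library's code:

    # Illustrative only: slicing size-bs batches from either accepted form.
    import torch

    def iter_batches(input_ids: torch.Tensor | list[torch.Tensor], bs: int):
        """Yield size-bs slices from a stacked tensor or a list of tensors."""
        # Both forms support identical slicing; only the length lookup differs.
        n = input_ids.shape[0] if isinstance(input_ids, torch.Tensor) else len(input_ids)
        for start in range(0, n, bs):
            yield input_ids[start : start + bs]
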
@@ -2805,7 +2811,7 @@ def _quantize_block(
             f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
             f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         )
-        logger.info(dump_info)
+
         if self.low_gpu_mem_usage:
             clear_memory(device_list=self.device_list)  # clear cached memory during training
         if len(unquantized_layer_names) != 0:
@@ -2833,6 +2839,8 @@ def _quantize_block(
                 mv_module_from_gpu(block)

             clear_memory(input_ids)
+            memory_info_summary = memory_monitor.get_summary()
+            logger.infoclean(dump_info + "," + memory_info_summary)

             return q_outputs, output
         else:
@@ -2841,6 +2849,8 @@ def _quantize_block(
             if auto_offload:
                 mv_module_from_gpu(block)
             clear_memory(input_ids)
+            memory_info_summary = memory_monitor.get_summary()
+            logger.infoclean(dump_info + "," + memory_info_summary)

             return None, output

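Both branches above defer the dump_info message (formerly logged right after it was built) and append the memory summary, so each block emits one combined record. logger.infoclean is a project-specific method; a minimal generic sketch of this deferred-logging pattern, with the standard logging module standing in as an assumption:

    # Generic sketch of the deferred-logging pattern used above; the standard
    # logging module stands in for the project's logger.infoclean helper.
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("quantize")

    def finish_block(dump_info: str, memory_summary: str) -> None:
        # One combined record per block instead of two separate lines.
        logger.info(dump_info + "," + memory_summary)
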
@@ -3174,7 +3184,7 @@ def _sampling_inputs(
         cls,
         input_ids: Union[list[torch.Tensor], dict],
         input_others: dict,
-        indices: list[int],
+        indices: list[int] | torch.Tensor,
         seqlen: int,
         batch_dim: int = 0,
         share_cache_keys: tuple = (),
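
The indices widening mirrors the earlier input changes: callers may now pass a tensor of sample indices (e.g., from torch.randperm) instead of a Python list. A hedged sketch of how a caller might normalize both forms; the helper name is illustrative:

    # Illustrative normalization for the widened indices annotation above.
    import torch

    def to_index_list(indices: list[int] | torch.Tensor) -> list[int]:
        # Tensors (e.g., from torch.randperm) are converted to plain ints so
        # downstream list-based indexing works unchanged.
        if isinstance(indices, torch.Tensor):
            return [int(i) for i in indices.tolist()]
        return list(indices)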