@@ -69,43 +69,50 @@ def main(config):
6969 for ppl_eval in eval_list :
7070 ppl = ppl_eval .eval (model )
7171 logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
72-
73- if not config .get ('calib' , False ):
74- blockwise_opt = ALGO_REGISTRY [config .quant .method ](
75- model ,
76- quant_config = config .quant ,
77- input = None ,
78- padding_mask = None ,
79- config = config
80- )
81- blockwise_opt .run_block_loop ()
82- dist .barrier ()
83- else :
84- dataset = BaseDataset (tokenizer .get_tokenizer (), config .calib , model .batch_process )
85- calib_data , padding_mask = dataset .get_calib_dataset ()
86- padding_side = getattr (tokenizer .get_tokenizer (), 'padding_side' , None )
87- model .collect_first_block_input (calib_data , padding_mask , padding_side , config .calib .type )
88- del calib_data
89- gc .collect ()
90- torch .cuda .empty_cache ()
91- if not config .get ('sparse' , False ):
72+ for modality in config .quant .get ('quant_objects' , ['language' ]):
73+ if not config .get ('calib' , False ):
9274 blockwise_opt = ALGO_REGISTRY [config .quant .method ](
9375 model ,
94- config .quant ,
95- model .get_first_block_input (),
96- model .get_padding_mask (),
97- config
76+ quant_config = config .quant ,
77+ input = None ,
78+ padding_mask = None ,
79+ config = config ,
80+ modality = modality ,
9881 )
82+ blockwise_opt .run_block_loop ()
83+ dist .barrier ()
9984 else :
100- blockwise_opt = ALGO_REGISTRY [config .sparse .method ](
101- model ,
102- config .sparse ,
103- model .get_first_block_input (),
104- model .get_padding_mask (),
105- config
106- )
107- blockwise_opt .run_block_loop ()
108- dist .barrier ()
85+ dataset = BaseDataset (tokenizer .get_tokenizer (), config .calib , model .batch_process )
86+ calib_data , padding_mask = dataset .get_calib_dataset ()
87+ padding_side = getattr (tokenizer .get_tokenizer (), 'padding_side' , None )
88+ model .collect_first_block_input (calib_data ,
89+ padding_mask ,
90+ padding_side ,
91+ config .calib .type ,
92+ modality )
93+ del calib_data
94+ gc .collect ()
95+ torch .cuda .empty_cache ()
96+ if not config .get ('sparse' , False ):
97+ blockwise_opt = ALGO_REGISTRY [config .quant .method ](
98+ model ,
99+ config .quant ,
100+ model .get_first_block_input (),
101+ model .get_padding_mask (),
102+ config ,
103+ modality
104+ )
105+ else :
106+ blockwise_opt = ALGO_REGISTRY [config .sparse .method ](
107+ model ,
108+ config .sparse ,
109+ model .get_first_block_input (),
110+ model .get_padding_mask (),
111+ config ,
112+ modality
113+ )
114+ blockwise_opt .run_block_loop ()
115+ dist .barrier ()
109116
110117 if int (os .environ ['RANK' ]) == 0 :
111118 if 'eval' in config and 'transformed' in config .eval .eval_pos :
0 commit comments