@@ -69,43 +69,46 @@ def main(config):
6969 for ppl_eval in eval_list :
7070 ppl = ppl_eval .eval (model )
7171 logger .info (f'{ ppl_eval .dataset } ppl : { ppl } ' )
72-
73- if not config .get ('calib' , False ):
74- blockwise_opt = ALGO_REGISTRY [config .quant .method ](
75- model ,
76- quant_config = config .quant ,
77- input = None ,
78- padding_mask = None ,
79- config = config
80- )
81- blockwise_opt .run_block_loop ()
82- dist .barrier ()
83- else :
84- dataset = BaseDataset (tokenizer .get_tokenizer (), config .calib , model .batch_process )
85- calib_data , padding_mask = dataset .get_calib_dataset ()
86- padding_side = getattr (tokenizer .get_tokenizer (), 'padding_side' , None )
87- model .collect_first_block_input (calib_data , padding_mask , padding_side , config .calib .type )
88- del calib_data
89- gc .collect ()
90- torch .cuda .empty_cache ()
91- if not config .get ('sparse' , False ):
72+ for modality in config .quant .get ('quant_objects' , ['language' ]):
73+ if not config .get ('calib' , False ):
9274 blockwise_opt = ALGO_REGISTRY [config .quant .method ](
9375 model ,
94- config .quant ,
95- model .get_first_block_input (),
96- model .get_padding_mask (),
97- config
76+ quant_config = config .quant ,
77+ input = None ,
78+ padding_mask = None ,
79+ config = config ,
80+ modality = modality ,
9881 )
82+ blockwise_opt .run_block_loop ()
83+ dist .barrier ()
9984 else :
100- blockwise_opt = ALGO_REGISTRY [config .sparse .method ](
101- model ,
102- config .sparse ,
103- model .get_first_block_input (),
104- model .get_padding_mask (),
105- config
106- )
107- blockwise_opt .run_block_loop ()
108- dist .barrier ()
85+ dataset = BaseDataset (tokenizer .get_tokenizer (), config .calib , model .batch_process )
86+ calib_data , padding_mask = dataset .get_calib_dataset ()
87+ padding_side = getattr (tokenizer .get_tokenizer (), 'padding_side' , None )
88+ model .collect_first_block_input (calib_data , padding_mask , padding_side , config .calib .type , modality )
89+ del calib_data
90+ gc .collect ()
91+ torch .cuda .empty_cache ()
92+ if not config .get ('sparse' , False ):
93+ blockwise_opt = ALGO_REGISTRY [config .quant .method ](
94+ model ,
95+ config .quant ,
96+ model .get_first_block_input (),
97+ model .get_padding_mask (),
98+ config ,
99+ modality
100+ )
101+ else :
102+ blockwise_opt = ALGO_REGISTRY [config .sparse .method ](
103+ model ,
104+ config .sparse ,
105+ model .get_first_block_input (),
106+ model .get_padding_mask (),
107+ config ,
108+ modality
109+ )
110+ blockwise_opt .run_block_loop ()
111+ dist .barrier ()
109112
110113 if int (os .environ ['RANK' ]) == 0 :
111114 if 'eval' in config and 'transformed' in config .eval .eval_pos :
0 commit comments