File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -193,6 +193,32 @@ def quant_embedding(model):
193193
# Quantize the token-embedding layer of the wrapped language model.
# NOTE(review): assumes `llava.model_.language_model.model` is the module
# whose embedding `quant_embedding` expects — confirm against the callee.
quantized_token_embed = quant_embedding(llava.model_.language_model.model)
195195
# --- Debug: verify the quantized copy matches the original quantized embedding ---
# The original module stores quantized values and scales as plain attributes.
qval = quantized_token_embed.embedding.weight
scale = quantized_token_embed.embedding.scales

# The copy wraps its weight in a tensor subclass; get_plain() returns the
# (qvals, scales, zero_points) tuple. Unpack it ONCE instead of calling
# get_plain() three times (the original re-extracted per component).
qval_copy, scale_copy, zero_copy = (
    quantized_token_embed_copy.embedding.weight.tensor_impl.get_plain()
)

print("COPY TENSOR", quantized_token_embed_copy.embedding.weight)
print("ORIGINAL DTYPE", quantized_token_embed.embedding.dtype)

print("COMPARING")
print("qval_copy", qval_copy)
print("qval", qval)
# Fraction of quantized values that agree exactly (1.0 => identical).
print("MATCHING", (qval_copy == qval).to(torch.float32).mean())

print("scale_copy", scale_copy)
print("scale", scale)
# Scales are floats: compare with tolerance, report fraction close.
print("ISCLOSE", torch.isclose(scale_copy, scale).to(torch.float32).mean())

print("zero_copy", zero_copy)
# Symmetric quantization is expected: all zero-points should be 0.
print("ALL ZEROS", (zero_copy == 0).to(torch.float32).mean())
# Declare a dynamic sequence-length dimension for export: dim 1 of the
# token input may vary from 2 up to the model's maximum sequence length.
token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
dynamic_shapes = [{1: token_dim_1}]
198224 with torch .no_grad ():
You can’t perform that action at this time.
0 commit comments