File tree — 1 file changed: +26 −0 lines changed
lines changed Original file line number Diff line number Diff line change @@ -199,6 +199,32 @@ def quant_embedding(model):
199199
200200 quantized_token_embed = quant_embedding (llava .model_ .language_model .model )
201201
202+ qval = quantized_token_embed .embedding .weight
203+ scale = quantized_token_embed .embedding .scales
204+
205+ qval_copy = quantized_token_embed_copy .embedding .weight .tensor_impl .get_plain ()[0 ]
206+ scale_copy = quantized_token_embed_copy .embedding .weight .tensor_impl .get_plain ()[1 ]
207+ zero_copy = quantized_token_embed_copy .embedding .weight .tensor_impl .get_plain ()[2 ]
208+
209+ print ("COPY TENSOR" , quantized_token_embed_copy .embedding .weight )
210+ print ("ORIGINAL DTYPE" , quantized_token_embed .embedding .dtype )
211+
212+ print ("COMPARING" )
213+ print ("qval_copy" , qval_copy )
214+ print ("qval" , qval )
215+ print ("MATCHING" , (qval_copy == qval ).to (torch .float32 ).mean ())
216+
217+ print ("scale_copy" , scale_copy )
218+ print ("scale" , scale )
219+ print ("ISCLOSE" , torch .isclose (scale_copy , scale ).to (torch .float32 ).mean ())
220+
221+ print ("zero_copy" , zero_copy )
222+ print ("ALL ZEROS" , (zero_copy == 0 ).to (torch .float32 ).mean ())
223+
224+
225+
226+
227+
202228 token_dim_1 = Dim ("token_dim_1" , min = 2 , max = llava .text_model_args .max_seq_len )
203229 dynamic_shapes = [{1 : token_dim_1 }]
204230 with torch .no_grad ():
You can’t perform that action at this time.
0 commit comments