Skip to content

Commit a08f4ec

Browse files
committed
up
1 parent f1a94d0 commit a08f4ec

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

examples/models/llava/export_llava.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,32 @@ def quant_embedding(model):
199199

200200
quantized_token_embed = quant_embedding(llava.model_.language_model.model)
201201

202+
qval = quantized_token_embed.embedding.weight
203+
scale = quantized_token_embed.embedding.scales
204+
205+
qval_copy = quantized_token_embed_copy.embedding.weight.tensor_impl.get_plain()[0]
206+
scale_copy = quantized_token_embed_copy.embedding.weight.tensor_impl.get_plain()[1]
207+
zero_copy = quantized_token_embed_copy.embedding.weight.tensor_impl.get_plain()[2]
208+
209+
print("COPY TENSOR", quantized_token_embed_copy.embedding.weight)
210+
print("ORIGINAL DTYPE", quantized_token_embed.embedding.dtype)
211+
212+
print("COMPARING")
213+
print("qval_copy", qval_copy)
214+
print("qval", qval)
215+
print("MATCHING", (qval_copy == qval).to(torch.float32).mean())
216+
217+
print("scale_copy", scale_copy)
218+
print("scale", scale)
219+
print("ISCLOSE", torch.isclose(scale_copy, scale).to(torch.float32).mean())
220+
221+
print("zero_copy", zero_copy)
222+
print("ALL ZEROS", (zero_copy == 0).to(torch.float32).mean())
223+
224+
225+
226+
227+
202228
token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
203229
dynamic_shapes = [{1: token_dim_1}]
204230
with torch.no_grad():

0 commit comments

Comments (0)