Commit 2c86116

Commit message: up
1 parent 42f6c79 commit 2c86116

File tree

1 file changed: +26 -0 lines changed


examples/models/llava/export_llava.py

Lines changed: 26 additions & 0 deletions
@@ -193,6 +193,32 @@ def quant_embedding(model):
 
     quantized_token_embed = quant_embedding(llava.model_.language_model.model)
 
+    qval = quantized_token_embed.embedding.weight
+    scale = quantized_token_embed.embedding.scales
+
+    qval_copy = quantized_token_embed_copy.embedding.weight.tensor_impl.get_plain()[0]
+    scale_copy = quantized_token_embed_copy.embedding.weight.tensor_impl.get_plain()[1]
+    zero_copy = quantized_token_embed_copy.embedding.weight.tensor_impl.get_plain()[2]
+
+    print("COPY TENSOR", quantized_token_embed_copy.embedding.weight)
+    print("ORIGINAL DTYPE", quantized_token_embed.embedding.dtype)
+
+    print("COMPARING")
+    print("qval_copy", qval_copy)
+    print("qval", qval)
+    print("MATCHING", (qval_copy == qval).to(torch.float32).mean())
+
+    print("scale_copy", scale_copy)
+    print("scale", scale)
+    print("ISCLOSE", torch.isclose(scale_copy, scale).to(torch.float32).mean())
+
+    print("zero_copy", zero_copy)
+    print("ALL ZEROS", (zero_copy == 0).to(torch.float32).mean())
+
+
+
+
+
     token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
     dynamic_shapes = [{1: token_dim_1}]
     with torch.no_grad():

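Note on the added debug code: it compares two quantizations of the same token embedding. `quantized_token_embed` exposes its integer values and scales directly as module attributes (`embedding.weight`, `embedding.scales`), while `quantized_token_embed_copy` (defined elsewhere in the file, not shown in this hunk) carries a torchao-style quantized weight whose `tensor_impl.get_plain()` is assumed to return `(int_data, scale, zero_point)`. Below is a minimal sketch of the same check factored into a standalone helper; the helper name and the attribute layout are assumptions taken from the diff above, not an API this commit defines.

import torch

def report_embedding_quant_match(plain_embedding, torchao_embedding, atol=1e-8):
    # Hypothetical helper, not part of export_llava.py: prints the same
    # comparison stats as the debug code added in this commit.

    # Plain layout (assumed): integer values and per-group scales stored
    # directly on the quantized embedding module.
    qval = plain_embedding.weight
    scale = plain_embedding.scales

    # torchao-style layout (assumed): unpack (int_data, scale, zero_point)
    # from the quantized tensor's implementation.
    qval_copy, scale_copy, zero_copy = torchao_embedding.weight.tensor_impl.get_plain()

    # Fraction of quantized integer values that match exactly.
    print("MATCHING", (qval_copy == qval).to(torch.float32).mean())
    # Fraction of scales that agree within floating-point tolerance.
    print("ISCLOSE", torch.isclose(scale_copy, scale, atol=atol).to(torch.float32).mean())
    # Symmetric quantization should leave every zero point at zero.
    print("ALL ZEROS", (zero_copy == 0).to(torch.float32).mean())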
0 commit comments