Skip to content

Commit f1a94d0

Browse files
committed
up
1 parent c29a9f6 commit f1a94d0

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

examples/models/llava/export_llava.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,11 @@ def forward(self, images):
184184

185185

186186
def export_token_embedding(llava, prompt):
187-
# quantized_token_embed = get_quant_embedding_transform("8,32")(
188-
# llava.model_.language_model.model
189-
# )
187+
import copy
188+
model_copy = copy.deepcopy(llava.model_.language_model.model)
189+
quantized_token_embed_copy = get_quant_embedding_transform("8,32")(
190+
model_copy,
191+
)
190192
def quant_embedding(model):
191193
return EmbeddingQuantHandler(
192194
model,
@@ -206,6 +208,15 @@ def quant_embedding(model):
206208
dynamic_shapes=dynamic_shapes,
207209
strict=True,
208210
)
211+
token_embedding_ep_copy = torch.export.export(
212+
quantized_token_embed_copy.embed_tokens,
213+
(prompt,),
214+
dynamic_shapes=dynamic_shapes,
215+
strict=True,
216+
)
217+
218+
print("token_embedding_ep_copy", token_embedding_ep_copy)
219+
print("token_embedding_ep", token_embedding_ep)
209220
return token_embedding_ep
210221

211222

0 commit comments

Comments
 (0)