 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
-    EmbeddingQuantHandler,
 )
 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
@@ -178,51 +177,9 @@ def forward(self, images):
 
 
 def export_token_embedding(llava, prompt):
-    import copy
-    model_copy = copy.deepcopy(llava.model_.language_model.model)
-    quantized_token_embed_copy = get_quant_embedding_transform("8,32")(
-        model_copy,
+    quantized_token_embed = get_quant_embedding_transform("8,32")(
+        llava.model_.language_model.model,
     )
-    def quant_embedding(model):
-        return EmbeddingQuantHandler(
-            model,
-            bitwidth=8,
-            group_size=32,
-            packed=False,
-        ).quantized_model()
-
-    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
-
-    print("GET ATTRS", quantized_token_embed)
-    print("GET ATTRS2", quantized_token_embed.embed_tokens)
-
-    qval = quantized_token_embed.embed_tokens.weight
-    scale = quantized_token_embed.embed_tokens.scales
-
-    qval_copy = quantized_token_embed_copy.embed_tokens.weight.tensor_impl.get_plain()[0]
-    scale_copy = quantized_token_embed_copy.embed_tokens.weight.tensor_impl.get_plain()[1]
-    zero_copy = quantized_token_embed_copy.embed_tokens.weight.tensor_impl.get_plain()[2]
-
-    print("COPY TENSOR", quantized_token_embed_copy.embed_tokens.weight)
-    print("ORIGINAL DTYPE", quantized_token_embed.embed_tokens.dtype)
-
-    print("COMPARING")
-    print("qval_copy", qval_copy)
-    print("qval", qval)
-    print("MATCHING", (qval_copy == qval).to(torch.float32).mean())
-    print("MAX DIFF", (qval_copy.to(torch.int32) - qval.to(torch.int32)).abs().max())
-
-    print("scale_copy", scale_copy)
-    print("scale", scale)
-    print("ISCLOSE", torch.isclose(scale_copy, scale).to(torch.float32).mean())
-
-    print("zero_copy", zero_copy)
-    print("ALL ZEROS", (zero_copy == 0).to(torch.float32).mean())
-
-
-
-
-
     token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
     dynamic_shapes = [{1: token_dim_1}]
     with torch.no_grad():
@@ -232,16 +189,7 @@ def quant_embedding(model):
             dynamic_shapes=dynamic_shapes,
             strict=True,
         )
-        token_embedding_ep_copy = torch.export.export(
-            quantized_token_embed_copy.embed_tokens,
-            (prompt,),
-            dynamic_shapes=dynamic_shapes,
-            strict=True,
-        )
-
-    print("token_embedding_ep_copy", token_embedding_ep_copy)
-    print("token_embedding_ep", token_embedding_ep)
-    return token_embedding_ep_copy
+    return token_embedding_ep
 
 
 def export_all(llava_model: LlavaModel):
@@ -302,7 +250,6 @@ def export_all(llava_model: LlavaModel):
             do_quant_fusion_and_const_prop=True,
         )
     )
-    logging.info("TOKEN EMBEDDING PROG", str(executorch_program.exported_program("token_embedding")))
     for execution_plan in executorch_program._emitter_output.program.execution_plan:
         logging.info(
             f"Required memory for activation in bytes: {execution_plan.non_const_buffer_sizes}"
        )
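For reference, the new token-embedding export path in this change reduces to the standalone sketch below: quantize the language model's embedding table with `get_quant_embedding_transform("8,32")` (8-bit weights, group size 32, as used in the diff), then export `embed_tokens` with `torch.export.export` and a dynamic token dimension. This is a sketch, not the exact file contents; `llava`, `prompt`, and `max_seq_len` are placeholders for a loaded `LlavaModel` instance, an example token-id tensor whose dim 1 is the sequence length, and the model's maximum sequence length.

```python
import torch
from torch.export import Dim

from executorch.examples.models.llama.source_transformation.quantize import (
    get_quant_embedding_transform,
)


def export_quantized_token_embedding(llava, prompt, max_seq_len):
    # "8,32" selects 8-bit embedding weights with a group size of 32,
    # matching the spec passed to get_quant_embedding_transform above.
    quantized = get_quant_embedding_transform("8,32")(
        llava.model_.language_model.model,
    )

    # Dim 1 of the prompt (the token dimension) is dynamic, so the exported
    # program accepts any sequence length from 2 up to max_seq_len.
    token_dim_1 = Dim("token_dim_1", min=2, max=max_seq_len)
    dynamic_shapes = [{1: token_dim_1}]

    with torch.no_grad():
        # Export only the quantized embedding lookup; the rest of the text
        # model is exported separately in export_all.
        return torch.export.export(
            quantized.embed_tokens,
            (prompt,),
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )
```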