Commit dfff847

Commit message: up
1 parent 73c08f5 commit dfff847

File tree: 5 files changed, +10 / -61 lines
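At a glance, this commit enables `do_quant_fusion_and_const_prop` in the CoreML, LLaVA, and Qualcomm export flows, drops the unused dtype override from `get_quant_embedding_transform`, renames the internal `_quantize_embedding` helper to `_embedding_quantizer`, and strips debugging scaffolding (prints and a duplicate quantization path) from the LLaVA export.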


examples/apple/coreml/llama/export.py

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ def main() -> None:
             ],
             memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+            do_quant_fusion_and_const_prop=True,
         )
     )
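For context, `do_quant_fusion_and_const_prop` is a field on ExecuTorch's `ExecutorchBackendConfig`, which the hunk above is constructing. A minimal sketch of how the flag is passed at `to_executorch` time; the `edge_manager` variable is a stand-in for an already-lowered edge program, and only the config field itself comes from this diff:

# Minimal sketch, assuming an already-lowered EdgeProgramManager in
# `edge_manager`; names other than the config field are illustrative.
from executorch.exir import ExecutorchBackendConfig

executorch_program = edge_manager.to_executorch(
    ExecutorchBackendConfig(
        # Fuse quantize/dequantize pairs and constant-propagate the
        # resulting quantized weights before emitting the program.
        do_quant_fusion_and_const_prop=True,
    )
)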

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 1 deletion
@@ -1342,7 +1342,8 @@ def _get_source_transforms(  # noqa
     """
     transforms.append(
         get_quant_embedding_transform(
-            embedding_quantize, use_shared_embedding, checkpoint_dtype
+            embedding_quantize,
+            use_shared_embedding,
         )
     )
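With the dtype override gone, callers pass only the quantization spec and the shared-embedding flag. A minimal sketch of the new call shape; the "8,32" spec (8-bit weights, group size 32) and the `model` variable are illustrative stand-ins:

# Sketch of the two-argument signature after this change; `model` is a
# stand-in for the eager llama model being transformed.
transform = get_quant_embedding_transform(
    embedding_quantize="8,32",
    use_shared_embedding=False,
)
model = transform(model)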

examples/models/llama/source_transformation/quantize.py

Lines changed: 3 additions & 4 deletions
@@ -510,6 +510,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.precision,
         )

+
 #########################################################################
 ##### embedding table quantization ######

@@ -734,7 +735,6 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 def get_quant_embedding_transform(
     embedding_quantize: str,
     use_shared_embedding: bool = False,
-    dtype_override: Optional[DType] = None,
 ):
     use_torchao = embedding_quantize.startswith("torchao:")
     if use_torchao:

@@ -783,13 +783,12 @@ def _torchao_embedding_quantizer(model):

         return _torchao_embedding_quantizer

-    def _quantize_embedding(model):
+    def _embedding_quantizer(model):
         assert weight_dtype in [
             torch.int2,
             torch.int4,
             torch.int8,
         ], "Only 2, 4, or 8-bit embeddings are supported unless using torchao"
-        print("GRAN", granularity)
         quantize_(
             model,
             IntxWeightOnlyConfig(

@@ -801,7 +800,7 @@ def _quantize_embedding(model):
             )
         return model

-    return _quantize_embedding
+    return _embedding_quantizer


 def get_quant_weight_transform(
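The renamed `_embedding_quantizer` closure is built on torchao's `quantize_` API with `IntxWeightOnlyConfig`, as the hunk above shows. A rough standalone sketch of that pattern; the embedding sizes, group size, and filter function are assumptions, not lifted from this file:

# Sketch: group-wise int8 weight-only quantization of an embedding table
# via torchao, in the style of _embedding_quantizer. Sizes and the filter
# function are illustrative assumptions.
import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

model = torch.nn.Embedding(num_embeddings=32000, embedding_dim=2048)
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerGroup(32)),
    # Restrict the transform to embedding modules, not linear layers.
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)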

examples/models/llava/export_llava.py

Lines changed: 3 additions & 56 deletions
@@ -26,7 +26,6 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
-    EmbeddingQuantHandler,
 )
 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,

@@ -178,51 +177,9 @@ def forward(self, images):


 def export_token_embedding(llava, prompt):
-    import copy
-    model_copy = copy.deepcopy(llava.model_.language_model.model)
-    quantized_token_embed_copy = get_quant_embedding_transform("8,32")(
-        model_copy,
+    quantized_token_embed = get_quant_embedding_transform("8,32")(
+        llava.model_.language_model.model,
     )
-    def quant_embedding(model):
-        return EmbeddingQuantHandler(
-            model,
-            bitwidth=8,
-            group_size=32,
-            packed=False,
-        ).quantized_model()
-
-    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
-
-    print("GET ATTRS", quantized_token_embed)
-    print("GET ATTRS2", quantized_token_embed.embed_tokens)
-
-    qval = quantized_token_embed.embed_tokens.weight
-    scale = quantized_token_embed.embed_tokens.scales
-
-    qval_copy = quantized_token_embed_copy.embed_tokens.weight.tensor_impl.get_plain()[0]
-    scale_copy = quantized_token_embed_copy.embed_tokens.weight.tensor_impl.get_plain()[1]
-    zero_copy = quantized_token_embed_copy.embed_tokens.weight.tensor_impl.get_plain()[2]
-
-    print("COPY TENSOR", quantized_token_embed_copy.embed_tokens.weight)
-    print("ORIGINAL DTYPE", quantized_token_embed.embed_tokens.dtype)
-
-    print("COMPARING")
-    print("qval_copy", qval_copy)
-    print("qval", qval)
-    print("MATCHING", (qval_copy == qval).to(torch.float32).mean())
-    print("MAX DIFF", (qval_copy.to(torch.int32) - qval.to(torch.int32)).abs().max())
-
-    print("scale_copy", scale_copy)
-    print("scale", scale)
-    print("ISCLOSE", torch.isclose(scale_copy, scale).to(torch.float32).mean())
-
-    print("zero_copy", zero_copy)
-    print("ALL ZEROS", (zero_copy == 0).to(torch.float32).mean())
-
-
-
-
-
     token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
     dynamic_shapes = [{1: token_dim_1}]
     with torch.no_grad():

@@ -232,16 +189,7 @@ def quant_embedding(model):
             dynamic_shapes=dynamic_shapes,
             strict=True,
         )
-        token_embedding_ep_copy = torch.export.export(
-            quantized_token_embed_copy.embed_tokens,
-            (prompt,),
-            dynamic_shapes=dynamic_shapes,
-            strict=True,
-        )
-
-        print("token_embedding_ep_copy", token_embedding_ep_copy)
-        print("token_embedding_ep", token_embedding_ep)
-        return token_embedding_ep_copy
+        return token_embedding_ep


 def export_all(llava_model: LlavaModel):

@@ -302,7 +250,6 @@ def export_all(llava_model: LlavaModel):
             do_quant_fusion_and_const_prop=True,
         )
     )
-    logging.info("TOKEN EMBEDDING PROG", str(executorch_program.exported_program("token_embedding")))
     for execution_plan in executorch_program._emitter_output.program.execution_plan:
         logging.info(
             f"Required memory for activation in bytes: {execution_plan.non_const_buffer_sizes}"

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 0 deletions
@@ -440,6 +440,7 @@ def lowering_modules(
                 alloc_graph_output=False,
             ),
             extract_delegate_segments=True,
+            do_quant_fusion_and_const_prop=True,
         )
         with torch.no_grad():
             # backend option
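This is the same `ExecutorchBackendConfig` flag enabled for the CoreML export above, so the Qualcomm flow now also fuses quantized ops and constant-propagates their weights when the program is emitted.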
