Skip to content

Commit 19bd7a6

Browse files
committed
up
1 parent 0b9b46d commit 19bd7a6

File tree

3 files changed

+7
-15
lines changed

3 files changed

+7
-15
lines changed

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -233,7 +233,7 @@ if [[ "${CUSTOM}" == "ON" ]]; then
233233
EXPORT_ARGS="${EXPORT_ARGS} --use_sdpa_with_kv_cache"
234234
fi
235235
if [[ "${QE}" == "ON" ]]; then
236-
EXPORT_ARGS="${EXPORT_ARGS} --embedding-quantize 8,1024"
236+
EXPORT_ARGS="${EXPORT_ARGS} --embedding-quantize 8,0"
237237
fi
238238
if [[ "${MPS}" == "ON" ]]; then
239239
EXPORT_ARGS="${EXPORT_ARGS} -kv -v --mps --disable_dynamic_shape"

examples/models/llama/source_transformation/quantize.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -572,6 +572,7 @@ def _quantize_embedding(model):
572572
torch.int4,
573573
torch.int8,
574574
], "Only 2, 4, or 8-bit embeddings are supported unless using torchao"
575+
print("GRAN", granularity)
575576
quantize_(
576577
model,
577578
IntxWeightOnlyConfig(

examples/models/llava/export_llava.py

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
replace_kv_cache_with_custom_kv_cache,
2525
)
2626
from executorch.examples.models.llama.source_transformation.quantize import (
27-
EmbeddingQuantHandler,
27+
get_quant_embedding_transform,
2828
get_quant_weight_transform,
2929
)
3030
from executorch.examples.models.llama.source_transformation.sdpa import (
@@ -38,7 +38,6 @@
3838
)
3939

4040
from executorch.exir.passes import MemoryPlanningPass
41-
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
4241
from executorch.exir.passes.sym_shape_eval_pass import (
4342
ConstraintBasedSymShapeEvalPass,
4443
HintBasedSymShapeEvalPass,
@@ -184,15 +183,9 @@ def forward(self, images):
184183

185184

186185
def export_token_embedding(llava, prompt):
187-
def quant_embedding(model):
188-
return EmbeddingQuantHandler(
189-
model,
190-
bitwidth=8,
191-
group_size=32,
192-
packed=False,
193-
).quantized_model()
194-
195-
quantized_token_embed = quant_embedding(llava.model_.language_model.model)
186+
quantized_token_embed = get_quant_embedding_transform(
187+
llava.model_.language_model.model
188+
)
196189
token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
197190
dynamic_shapes = [{1: token_dim_1}]
198191
with torch.no_grad():
@@ -254,15 +247,13 @@ def export_all(llava_model: LlavaModel):
254247
executorch_program = lowered_and_edge.to_executorch(
255248
ExecutorchBackendConfig(
256249
extract_delegate_segments=True,
257-
passes=[
258-
QuantFusionPass(),
259-
],
260250
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
261251
sym_shape_eval_pass={
262252
"image_encoder": ConstraintBasedSymShapeEvalPass(),
263253
"text_model": ConstraintBasedSymShapeEvalPass(),
264254
"token_embedding": HintBasedSymShapeEvalPass(),
265255
},
256+
do_quant_fusion_and_const_prop=True,
266257
)
267258
)
268259
for execution_plan in executorch_program._emitter_output.program.execution_plan:

0 commit comments

Comments (0)