diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index db533986119..37d6bb50bb2 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -659,21 +659,7 @@ def permute(w, heads): tokenizer=tokenizer, custom_annotations=custom_annotations, ) - # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later - if i == 0 and args.model_mode in ["hybrid", "lookahead"]: - output_indices = 0 - for node in llama_instance.llama_graph_module.graph.nodes: - if node.op == "output": - for output in node.args[0]: - kv_quant_attrs[output_indices] = output.args[1:] - output_indices += 1 - break - custom_annotations = custom_annotations + ( - partial( - annotate_prefill_kv_output, - kv_quant_attrs=kv_quant_attrs, - ), - ) + llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ "get_quant_io_dtype_fn"