
Commit 9303d73

cccclai authored and facebook-github-bot committed
Remove custom annotation for kv attrs (#12526)
Summary: Remove the custom annotation, and copy the prefill kv quant attrs to the decode kv quant attrs.

Rollback Plan:

Differential Revision: D78375468
1 parent 4d7f9ca commit 9303d73
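
What the change amounts to: instead of installing the custom annotate_prefill_kv_output annotation to force shared kv quant params, quantization now records each instance's output quant attrs and copies the prefill attrs onto the decode (kv) outputs, warning when the two calibrations disagree by more than 1% relative tolerance. A minimal sketch of the idea, assuming each output carries a (scale, zero_point) pair; the names and values below are hypothetical, not the actual ExecuTorch node layout:

# Hedged sketch only: models each graph output as (name, scale, zero_point)
# and reuses the prefill quant params for decode, mirroring what the commit
# does on the FX output nodes.

def rtol(a, b):
    # relative difference between two scalar quant params
    return abs(a - b) / max(abs(a), abs(b))

prefill_outputs = [("k_cache", 0.0210, 0), ("v_cache", 0.0105, 0)]  # hypothetical calibration results
decode_outputs = [("k_cache", 0.0212, 0), ("v_cache", 0.0130, 0)]   # hypothetical calibration results

patched_decode = []
for (name, p_scale, p_zp), (_, d_scale, d_zp) in zip(prefill_outputs, decode_outputs):
    if rtol(p_scale, d_scale) > 0.01:
        print(f"Warning: {name} prefill/decode quant attrs differ, rtol is {rtol(p_scale, d_scale):.3f}")
    patched_decode.append((name, p_scale, p_zp))  # decode reuses the prefill attrs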

File tree

  • examples/qualcomm/oss_scripts/llama/llama.py

1 file changed: 26 additions, 8 deletions

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 26 additions & 8 deletions
@@ -484,6 +484,8 @@ def get_example_inputs(self, use_kv_cache=True):
     def get_quant_attrs(self):
         return self.quant_attrs
 
+def rtol(a, b):
+    return abs(a - b) / max(abs(a), abs(b))
 
 def compile(args, pte_filename, tokenizer):
     os.makedirs(args.artifact, exist_ok=True)
@@ -651,33 +653,49 @@ def permute(w, heads):
         custom_annotations = custom_annotations + (
             annotate_linear_16a8w_in_affine_layer,
         )
-    kv_quant_attrs = {}
+    # kv_quant_attrs_list[0] holds the prefill attrs, kv_quant_attrs_list[1] the decode (kv) attrs
+    kv_quant_attrs_list = []
     for i, llama_instance in enumerate(llama_instance_list):
+        print("Running llama instance: ", i)
+        kv_quant_attrs = {}
         llama_instance.quantize(
             quant_dtype=quant_dtype,
             args=args,
             tokenizer=tokenizer,
             custom_annotations=custom_annotations,
         )
         # In hybrid and lookahead mode, store each instance's output quant_attrs; the prefill attrs are copied onto the decode outputs below
-        if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
+        if args.model_mode in ["hybrid", "lookahead"]:
             output_indices = 0
             for node in llama_instance.llama_graph_module.graph.nodes:
                 if node.op == "output":
                     for output in node.args[0]:
                         kv_quant_attrs[output_indices] = output.args[1:]
                         output_indices += 1
                     break
-            custom_annotations = custom_annotations + (
-                partial(
-                    annotate_prefill_kv_output,
-                    kv_quant_attrs=kv_quant_attrs,
-                ),
-            )
+            kv_quant_attrs_list.append(kv_quant_attrs)
         llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
         llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
         ] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type)
+
+    # Check that the prefill and decode graphs have the same structure
+    assert len(llama_instance_list[0].llama_graph_module.graph.nodes) == len(llama_instance_list[1].llama_graph_module.graph.nodes)
+    for prefill_node, decode_node in zip(llama_instance_list[0].llama_graph_module.graph.nodes, llama_instance_list[1].llama_graph_module.graph.nodes):
+        assert prefill_node.op == decode_node.op
+        assert prefill_node.target == decode_node.target
+
+    # Copy the prefill output quant_attrs onto the decode (kv) outputs
+    output_indices = 0
+    for prefill_node, decode_node in zip(llama_instance_list[0].llama_graph_module.graph.nodes, llama_instance_list[1].llama_graph_module.graph.nodes):
+        if prefill_node.op == "output" and decode_node.op == "output":
+            for prefill_output, decode_output in zip(prefill_node.args[0], decode_node.args[0]):
+                kv_quant_attrs[output_indices] = prefill_output.args[1:]
+                output_indices += 1
+                rtol_value = rtol(prefill_output.args[1], decode_output.args[1])
+                if rtol_value > 0.01:
+                    print(f"Warning: prefill and decode output quant_attrs are different, rtol is {rtol_value}")
+                decode_output.args = tuple([decode_output.args[0]] + list(prefill_output.args[1:]))
+
     end_quantize_ts = time.time()
     logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
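
The pairwise copy above relies on the prefill and decode graphs being node-for-node identical, which the new asserts verify before anything is overwritten. A standalone illustration of that precondition with plain torch.fx (the Toy module is made up for the example):

import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2

# Tracing the same module twice yields structurally identical graphs,
# which is what makes a pairwise zip over their nodes meaningful.
g0 = fx.symbolic_trace(Toy()).graph
g1 = fx.symbolic_trace(Toy()).graph

assert len(list(g0.nodes)) == len(list(g1.nodes))
for n0, n1 in zip(g0.nodes, g1.nodes):
    assert n0.op == n1.op and n0.target == n1.target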
