Skip to content

Commit 6940ded

Browse files
cccclai (facebook-github-bot)
authored and committed
Remove custom annotation for kv attrs (#12526)
Summary: Pull Request resolved: #12526 Remove the custom annotation, and copy the prefill kv quant attrs to decode kv quant attrs Rollback Plan: Differential Revision: D78375468
1 parent 4d7f9ca commit 6940ded

File tree

1 file changed

+22
-16
lines changed
  • examples/qualcomm/oss_scripts/llama

1 file changed

+22
-16
lines changed

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ def get_example_inputs(self, use_kv_cache=True):
484484
def get_quant_attrs(self):
    """Return the quantization attributes stored on this instance."""
    attrs = self.quant_attrs
    return attrs
486486

487+
def rtol(a, b):
    """Return the relative difference between *a* and *b*.

    Computed as |a - b| / max(|a|, |b|), the same symmetric relative
    tolerance used by ``math.isclose``.

    Returns 0.0 when both values are zero — the naive formula would
    raise ``ZeroDivisionError`` there, and two zero quant attrs are
    in fact identical.
    """
    denom = max(abs(a), abs(b))
    # Guard the a == b == 0 case: denominator is 0 but the values match.
    return abs(a - b) / denom if denom else 0.0
487489

488490
def compile(args, pte_filename, tokenizer):
489491
os.makedirs(args.artifact, exist_ok=True)
@@ -651,33 +653,37 @@ def permute(w, heads):
651653
custom_annotations = custom_annotations + (
652654
annotate_linear_16a8w_in_affine_layer,
653655
)
654-
kv_quant_attrs = {}
656+
655657
for i, llama_instance in enumerate(llama_instance_list):
658+
print("Running llama instance: ", i)
656659
llama_instance.quantize(
657660
quant_dtype=quant_dtype,
658661
args=args,
659662
tokenizer=tokenizer,
660663
custom_annotations=custom_annotations,
661664
)
662-
# If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
663-
if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
664-
output_indices = 0
665-
for node in llama_instance.llama_graph_module.graph.nodes:
666-
if node.op == "output":
667-
for output in node.args[0]:
668-
kv_quant_attrs[output_indices] = output.args[1:]
669-
output_indices += 1
670-
break
671-
custom_annotations = custom_annotations + (
672-
partial(
673-
annotate_prefill_kv_output,
674-
kv_quant_attrs=kv_quant_attrs,
675-
),
676-
)
677665
llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
678666
llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
679667
"get_quant_io_dtype_fn"
680668
] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type)
669+
670+
# Check prefill and decode graph is the same:
671+
assert len(llama_instance_list[0].llama_graph_module.graph.nodes) == len(llama_instance_list[1].llama_graph_module.graph.nodes)
672+
for prefill_node, decode_node in zip(llama_instance_list[0].llama_graph_module.graph.nodes, llama_instance_list[1].llama_graph_module.graph.nodes):
673+
assert prefill_node.op == decode_node.op
674+
assert prefill_node.target == decode_node.target
675+
676+
# Copy prefill quant_attrs to kv output quant_attrs
677+
for prefill_node, decode_node in zip(llama_instance_list[0].llama_graph_module.graph.nodes, llama_instance_list[1].llama_graph_module.graph.nodes):
678+
if prefill_node.op == "output" and decode_node.op == "output":
679+
output_index = 0
680+
for prefill_output, decode_output in zip(prefill_node.args[0], decode_node.args[0]):
681+
rtol_value = rtol(prefill_output.args[1], decode_output.args[1])
682+
if rtol_value > 0.01:
683+
print(f"Warning: prefill and decode output at output_index={output_index} quant_attrs are different, rtol is {rtol_value}")
684+
decode_node.args = tuple([decode_node.args[0]] + list(prefill_node.args[1:]))
685+
output_index += 1
686+
681687
end_quantize_ts = time.time()
682688
logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
683689

0 commit comments

Comments
 (0)