@@ -484,6 +484,8 @@ def get_example_inputs(self, use_kv_cache=True):
     def get_quant_attrs(self):
         return self.quant_attrs
 
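+# Relative difference between two values, used below to compare prefill vs
+# decode output quant scales; e.g. rtol(1.00, 1.02) ~= 0.0196, which trips
+# the 0.01 warning threshold. The 1e-12 floor is an added guard against
+# division by zero when both inputs are 0.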
+def rtol(a, b):
+    return abs(a - b) / max(abs(a), abs(b), 1e-12)
 
 def compile(args, pte_filename, tokenizer):
     os.makedirs(args.artifact, exist_ok=True)
@@ -651,33 +653,49 @@ def permute(w, heads):
         custom_annotations = custom_annotations + (
             annotate_linear_16a8w_in_affine_layer,
         )
-    kv_quant_attrs = {}
+    # kv_quant_attrs_list[0] holds the prefill output quant attrs and
+    # kv_quant_attrs_list[1] the decode (kv) output quant attrs
+    kv_quant_attrs_list = []
     for i, llama_instance in enumerate(llama_instance_list):
+        logging.info(f"Running llama instance: {i}")
+        # fresh dict per instance; appended to kv_quant_attrs_list below
+        kv_quant_attrs = {}
         llama_instance.quantize(
             quant_dtype=quant_dtype,
             args=args,
             tokenizer=tokenizer,
             custom_annotations=custom_annotations,
         )
-        # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
+        # In hybrid and lookahead mode, store each instance's output quant_attrs;
+        # prefill's are applied to the decode (kv) outputs after quantization
-        if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
+        if args.model_mode in ["hybrid", "lookahead"]:
             output_indices = 0
             for node in llama_instance.llama_graph_module.graph.nodes:
                 if node.op == "output":
                     for output in node.args[0]:
                         kv_quant_attrs[output_indices] = output.args[1:]
                         output_indices += 1
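+                    # an FX graph has exactly one output node, so stop here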
                     break
-            custom_annotations = custom_annotations + (
-                partial(
-                    annotate_prefill_kv_output,
-                    kv_quant_attrs=kv_quant_attrs,
-                ),
-            )
+            kv_quant_attrs_list.append(kv_quant_attrs)
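+        # enable the TagQuantIO pass below; _tag_ios decides which graph I/Os
+        # get tagged with the quantized fixed_point_type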
         llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
         llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
         ] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type)
+
+    if args.model_mode in ["hybrid", "lookahead"]:
+        # Check that the prefill and decode graphs are structurally identical,
+        # so their outputs line up index by index for the copy below:
+        assert len(llama_instance_list[0].llama_graph_module.graph.nodes) == len(llama_instance_list[1].llama_graph_module.graph.nodes)
+        for prefill_node, decode_node in zip(llama_instance_list[0].llama_graph_module.graph.nodes, llama_instance_list[1].llama_graph_module.graph.nodes):
+            assert prefill_node.op == decode_node.op
+            assert prefill_node.target == decode_node.target
+
+        # Copy prefill output quant_attrs to the decode (kv) outputs
+        output_indices = 0
+        for prefill_node, decode_node in zip(llama_instance_list[0].llama_graph_module.graph.nodes, llama_instance_list[1].llama_graph_module.graph.nodes):
+            if prefill_node.op == "output" and decode_node.op == "output":
+                for prefill_output, decode_output in zip(prefill_node.args[0], decode_node.args[0]):
+                    # kv_quant_attrs still references kv_quant_attrs_list[1],
+                    # so the decode entry is updated in place
+                    kv_quant_attrs[output_indices] = prefill_output.args[1:]
+                    output_indices += 1
+                    rtol_value = rtol(prefill_output.args[1], decode_output.args[1])
+                    if rtol_value > 0.01:
+                        logging.warning(f"prefill and decode output quant_attrs differ, rtol is {rtol_value}")
+                    decode_output.args = tuple([decode_output.args[0]] + list(prefill_output.args[1:]))
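+    # In hybrid/lookahead mode the decode (kv) outputs now carry the prefill
+    # quant attrs, so both graphs encode outputs with matching scale/zero-point.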
+
     end_quantize_ts = time.time()
     logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
 