     fuse_layer_norms,
     get_model_with_r1_r2,
 )
+
+from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
@@ -174,6 +176,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--use_qnn_sha",
+        action="store_true",
+        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",
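The new flag is a plain store_true switch on the existing parser, so downstream code can branch on args.use_qnn_sha. A minimal usage sketch, assuming the parser is importable as shown (the module path below is a guess for illustration; only build_args_parser and --use_qnn_sha come from this commit):

# Hypothetical import path; adjust to wherever export_llama_lib lives in your checkout.
from executorch.examples.models.llama2.export_llama_lib import build_args_parser

args = build_args_parser().parse_args(["--use_qnn_sha"])
assert args.use_qnn_sha  # store_true: defaults to False when the flag is omitted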
@@ -642,7 +650,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-        from executorch.backends.qualcomm.utils.utils import _transform
+        from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
@@ -654,7 +662,32 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
+        from functools import partial
 
+        from executorch.backends.qualcomm.quantizer.custom_annotation import (
+            get_custom_quant_ios_dtype,
+        )
+        atten = builder_exported_to_edge.model.layers[0].attention
+        if args.use_qnn_sha:
+            cache_shape = torch.Size(
+                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+            )
+        else:
+            cache_shape = torch.Size(
+                (
+                    atten.max_batch_size,
+                    atten.max_seq_len,
+                    atten.n_kv_heads,
+                    atten.head_dim,
+                )
+            )
+        tag_quant_io(
+            builder_exported_to_edge.edge_manager.exported_program().graph_module,
+            partial(
+                get_custom_quant_ios_dtype,
+                cache_shape,
+            ),
+        )
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
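The branch above selects which KV-cache layout gets tagged for custom quantized I/O: with --use_qnn_sha each single-head attention owns its own cache, so the per-cache shape drops the n_kv_heads dimension that the multi-head layout keeps. A small sketch with made-up sizes (the numbers are illustrative, not from the commit):

import torch

# Illustrative sizes only.
max_batch_size, max_seq_len, n_kv_heads, head_dim = 1, 128, 8, 64

# SHA path: one cache per head, so no n_kv_heads dimension.
sha_cache_shape = torch.Size((max_batch_size, max_seq_len, head_dim))

# Default multi-head path: a single cache spans all KV heads.
mha_cache_shape = torch.Size((max_batch_size, max_seq_len, n_kv_heads, head_dim))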
@@ -919,15 +952,27 @@ def _get_source_transforms(  # noqa
             convert_linear_to_conv2d,
         )
 
-        transforms.append(replace_kv_cache_with_simple_kv_cache)
-        transforms.append(replace_sdpa_with_flex_sdpa)
-        transforms.append(replace_causal_mask)
-        transforms.append(replace_rms_norm_with_native_rms_norm)
-        if args.optimized_rotation_path:
-            transforms.append(fuse_layer_norms)
-            transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-        # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
-        transforms.append(convert_linear_to_conv2d)
+        if args.use_qnn_sha:
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(replace_attention_to_attention_sha)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            transforms.append(convert_linear_to_conv2d)
+        else:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(convert_linear_to_conv2d)
 
     elif args.mps:
         # Currently mps doesn't support sdpa op, use the simpler decomposition
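For readers unfamiliar with the SHA rewrite, the sketch below illustrates the idea behind replace_attention_to_attention_sha in plain PyTorch: multi-head attention computed as independent single-head attentions over per-head slices of the projections, then concatenated. This is only an illustration of the equivalence, not the actual transform in this commit:

import torch
import torch.nn.functional as F


def mha_as_single_head_attentions(q, k, v, n_heads):
    # q, k, v: (batch, seq_len, n_heads * head_dim)
    head_dim = q.shape[-1] // n_heads
    outputs = []
    for h in range(n_heads):
        # Each head attends independently over its own slice of the projections.
        s = slice(h * head_dim, (h + 1) * head_dim)
        outputs.append(
            F.scaled_dot_product_attention(q[..., s], k[..., s], v[..., s])
        )
    # Concatenating the per-head results reproduces standard multi-head attention
    # (before the output projection).
    return torch.cat(outputs, dim=-1)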