 import operator
 import re
+import time
 import warnings
 from collections import OrderedDict
 from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
@@ -740,17 +741,17 @@ def preprocess_binary(ctx_bin, compiler_specs):
     for k, v in type_map.items():
         dtype_map.setdefault(v, k)

-    qnn_in_order, executorch_in_order, executorch_out_order = [], [], []
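+    # None (rather than empty lists) marks the custom orders as "not provided"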
+    qnn_in_order, executorch_in_order, executorch_out_order = None, None, None
     if custom_info is not None:
         # since some context binaries might fail to open on host
         # if they are compiled with special flags:
         # e.g. weight sharing
         # use custom information here instead
         inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
         outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
-        qnn_in_order = custom_info["qnn_in_order"]
-        executorch_in_order = custom_info["executorch_in_order"]
-        executorch_out_order = custom_info["executorch_out_order"]
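+        # the order entries are optional; fall back to None when a key is absent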
+        qnn_in_order = custom_info.get("qnn_in_order", None)
+        executorch_in_order = custom_info.get("executorch_in_order", None)
+        executorch_out_order = custom_info.get("executorch_out_order", None)
         graph_name = custom_info["graph_name"]
     else:
         # get context-binary io tensor info through qnn manager
@@ -800,7 +801,9 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule):

 def generate_multi_graph_program(
     compiler_specs: List[CompileSpec],
-    exported_programs: List[ExportedProgram] = None,
+    processed_bytes: List[bytes],
+    input_nodes_dict: List[torch.fx.Node] = None,
+    output_nodes_dict: List[torch.fx.Node] = None,
     backend_config: ExecutorchBackendConfig = None,
     constant_methods: Optional[Dict[str, Any]] = None,
 ) -> ExecutorchProgramManager:
@@ -813,10 +816,6 @@ def generate_multi_graph_program(
         executorch_in_order,
         executorch_out_order,
     ) = ({}, {}, {}, {}, {})
-
-    processed_bytes = [
-        prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs
-    ]
     qnn_mgr = PyQnnManagerAdaptor.QnnManager(
         generate_qnn_executorch_option(compiler_specs), processed_bytes
     )
@@ -829,38 +828,36 @@ def generate_multi_graph_program(
         graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)

     # We need to obtain the order of the IOs to correctly map QNN with nn.module
-    for i, graph_name in enumerate(graph_names):
-        # input
-        input_names = [
-            node.name
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "placeholder"
-        ]
-        qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]]
-        input_order_list = []
-        for input_name in input_names:
-            # e.g., input_0_tokens_0
-            pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
-            for j in range(len(qnn_input_names)):
-                if re.match(pattern, qnn_input_names[j]):
-                    input_order_list.append(j)
-                    break
-        assert (
-            len(input_order_list) == len(input_names) == len(qnn_input_names)
-        ), "Order list length is different from names"
-        executorch_in_order[graph_name] = input_order_list
-        qnn_in_order[graph_name] = sorted(
-            range(len(input_order_list)), key=lambda k: input_order_list[k]
-        )
-
-        # output
-        get_item_list = [
-            node
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "output"
-        ][0].args[0]
-        output_order_list = [item.args[1] for item in get_item_list]
-        executorch_out_order[graph_name] = output_order_list
+    for graph_name in graph_names:
+        if input_nodes_dict:
+            # input
+            input_names = [node.name for node in input_nodes_dict[graph_name]]
+            qnn_input_names = [
+                wrapper.GetName() for wrapper in graph_inputs[graph_name]
+            ]
+            # the inputs of an intermediate module may include call_function nodes,
+            # which cannot be reordered by node name; only reorder when lengths match
+            if len(input_names) == len(qnn_input_names):
+                input_order_list = []
+                for input_name in input_names:
+                    # e.g., input_0_tokens_0
+                    pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
+                    for j in range(len(qnn_input_names)):
+                        if re.match(pattern, qnn_input_names[j]):
+                            input_order_list.append(j)
+                            break
+                assert len(input_order_list) == len(
+                    input_names
+                ), "Order list length is different from names"
+                executorch_in_order[graph_name] = input_order_list
+                qnn_in_order[graph_name] = sorted(
+                    range(len(input_order_list)), key=lambda k: input_order_list[k]
+                )
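+                # qnn_in_order is the inverse permutation: for each QNN input,
+                # its position in the executorch input order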
+        if output_nodes_dict:
+            # output
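+            # each entry is expected to be a getitem node whose args[1] is its output index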
+            get_item_list = output_nodes_dict[graph_name][0].args[0]
+            output_order_list = [item.args[1] for item in get_item_list]
+            executorch_out_order[graph_name] = output_order_list

     qnn_mgr.Destroy()

@@ -869,15 +866,15 @@ def generate_multi_graph_program(
     bundle_progs = [
         from_context_binary(
             ctx_path=binary_info,
-            op_name=f"loader_{graph_name}",
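+            # timestamp suffix to avoid op-name collisions across repeated loads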
+            op_name=f"loader_{graph_name}_{int(time.time())}",
             soc_model=compiler_options.soc_info.soc_model,
             custom_info={
                 "graph_inputs": graph_inputs[graph_name],
                 "graph_outputs": graph_outputs[graph_name],
                 "graph_name": graph_name,
-                "qnn_in_order": qnn_in_order[graph_name],
-                "executorch_in_order": executorch_in_order[graph_name],
-                "executorch_out_order": executorch_out_order[graph_name],
+                "qnn_in_order": qnn_in_order.get(graph_name, None),
+                "executorch_in_order": executorch_in_order.get(graph_name, None),
+                "executorch_out_order": executorch_out_order.get(graph_name, None),
             },
         )
         for graph_name in graph_names
@@ -900,9 +897,101 @@ def generate_multi_graph_program(
             break

     edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
-    return edge_prog_mgr.to_executorch(
+    exec_prog = edge_prog_mgr.to_executorch(
+        config=backend_config or ExecutorchBackendConfig()
+    )
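+    # hand back the per-graph bundle programs alongside the executorch program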
+    return exec_prog, bundle_progs
+
+
+def generate_composite_llama_program(
+    graph_names: List[str],
+    sample_inputs_list: List[Tuple[Any]],
+    lower_module_dict: Dict[str, List[LoweredBackendModule]],
+    call_delegate_node_name_dict: Dict[str, List[str]],
+    call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]],
+    outputs_dict: Dict[str, List[Tuple[str, int]]],
+    backend_config: ExecutorchBackendConfig = None,
+    constant_methods: Optional[Dict[str, Any]] = None,
+) -> ExecutorchProgramManager:
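+    """Stitch per-graph lowered modules into a single multi-method program.
+
+    Each graph is wrapped in a CompositeLlamaModule that replays its delegate
+    calls in order; the wrapped modules are then exported and lowered together.
+    """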
+    class CompositeLlamaModule(torch.nn.Module):
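+        """Replays one graph's lowered delegate calls in their recorded order."""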
+        def __init__(
+            self,
+            lower_module_list,
+            call_delegate_node_name_list,
+            call_delegate_inputs_list,
+            outputs_list,
+        ) -> None:
+            super().__init__()
+            self.lower_module_list = lower_module_list
+            self.call_delegate_node_name_list = call_delegate_node_name_list
+            self.call_delegate_inputs_list = call_delegate_inputs_list
+            self.outputs_list = outputs_list
+
+        def reorder(
+            self,
+            call_delegate_inputs: List[Tuple[str, int | None]],
+            module_inputs: dict[str, torch.Tensor],
+            all_ret: dict[str, torch.Tensor],
+        ) -> Tuple[torch.Tensor]:
+            ret = []
+            for name, index in call_delegate_inputs:
+                if index is not None:
+                    # get tensor from a previous delegate's results
+                    ret.append(all_ret[name][index])
+                else:
+                    # get tensor from the module's own inputs
+                    ret.append(module_inputs[name])
+            return tuple(ret)
+
+        def forward(
+            self,
+            tokens: torch.Tensor,
+            atten_mask: torch.Tensor,
+            input_pos: Optional[torch.Tensor] = None,
+            *args,
+        ) -> Tuple[torch.Tensor]:
+            all_ret = {}
+            module_input_dict = {
+                "tokens": tokens,
+                "atten_mask": atten_mask,
+                "input_pos": input_pos,
+            }
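+            # extra positional inputs are keyed as args_0, args_1, ...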
+            for num, arg in enumerate(args):
+                module_input_dict[f"args_{num}"] = arg
+            for lower_module, call_delegate_node_name, call_delegate_inputs in zip(
+                self.lower_module_list,
+                self.call_delegate_node_name_list,
+                self.call_delegate_inputs_list,
+            ):
+                inp = self.reorder(call_delegate_inputs, module_input_dict, all_ret)
+                ret = lower_module(*inp)
+                all_ret[call_delegate_node_name] = ret
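+                # cache this delegate's outputs so later delegates and the
+                # final outputs can consume them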
+            llama_outputs = []
+            for output_src_name, index in self.outputs_list:
+                llama_outputs.append(all_ret[output_src_name][index])
+            return tuple(llama_outputs)
+
+    progs_dict = {}
+    for graph_name, sample_inputs in zip(graph_names, sample_inputs_list):
+        composite_llama_module = CompositeLlamaModule(
+            lower_module_dict[graph_name],
+            call_delegate_node_name_dict[graph_name],
+            call_delegate_inputs_dict[graph_name],
+            outputs_dict[graph_name],
+        )
+        prog = torch.export.export(composite_llama_module, sample_inputs)
+        progs_dict[graph_name] = prog
+    # leverage ExecutorchProgramManager to generate a pte with multiple methods
+    edge_prog_mgr = to_edge(
+        progs_dict,
+        constant_methods=constant_methods,
+        # do not alter names for custom ops
+        compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False),
+    )
+    exec_prog = edge_prog_mgr.to_executorch(
         config=backend_config or ExecutorchBackendConfig()
     )
+    return exec_prog


 def generate_htp_compiler_spec(