@@ -317,60 +317,126 @@ def get_decomp_table(passes_job) -> Dict[torch._ops.OperatorBase, Callable]:
 
 
 def to_edge_transform_and_lower_to_qnn(
-    module: Union[torch.nn.Module, torch.fx.GraphModule],
-    inputs: Tuple[torch.Tensor],
-    compiler_specs: List[CompileSpec],
+    module: Union[
+        torch.nn.Module,
+        torch.fx.GraphModule,
+        Dict[str, torch.nn.Module],
+        Dict[str, torch.fx.GraphModule],
+    ],
+    inputs: Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]],
+    compiler_specs: Union[List[Any], Dict[str, List[Any]]],
     constant_methods: Optional[Dict[str, Any]] = None,
     dynamic_shapes: Optional[Dict] = None,
     dep_table: Optional[Dict] = None,
-    passes_job: Optional[OrderedDict] = None,
+    passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None,
     skip_node_id_set: Optional[set] = None,
     skip_node_op_set: Optional[set] = None,
330335 """
331- Transforms and lowers a given PyTorch module to QNN backend.
336+ Transforms and lowers a given PyTorch module to the QNN backend.
332337
333338 Args:
334- module (Union[torch.nn.Module, torch.fx.GraphModule]): The PyTorch module or fx.GraphModule to be transformed.
335- inputs (Tuple[torch.Tensor]): The input tensors for the module.
336- compiler_specs (List[CompileSpec]): Compiler specs for Qualcomm AI Engine Direct.
337- constant_methods (Optional[Dict[str, Any]]): An optional dictionary of method name to the constant value
338- returned by that method in eager mode. Often used to store config information on
339- Edge models.
340- dynamic_shapes (Optional[Dict]): Information about dynamic shapes.
341- dep_table (Optional[Dict]): Dependency table for the transformation passes.
342- passes_job (Optional[OrderedDict]): Ordered dictionary of transformation passes.
343- skip_node_id_set (Optional[set]): Set of node IDs to skip during partitioning.
344- skip_node_op_set (Optional[set]): Set of node operations to skip during partitioning.
339+ module (Union[torch.nn.Module, torch.fx.GraphModule,Dict[str, torch.nn.Module], Dict[str, torch.fx.GraphModule]]):
340+ The PyTorch module or fx.GraphModule to be transformed.
341+ inputs (Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]]):
342+ The input tensors for the module.
343+ compiler_specs (Union[List[Any], Dict[str, List[Any]]]):
344+ Compiler specifications for Qualcomm AI Engine Direct.
345+ constant_methods (Optional[Dict[str, Any]]):
346+ An optional dictionary mapping method names to constant values returned by those methods in eager mode.
347+ Often used to store configuration information on Edge models.
348+ dynamic_shapes (Optional[Dict]):
349+ Information about dynamic shapes.
350+ dep_table (Optional[Dict]):
351+ Dependency table for the transformation passes.
352+ passes_job (Optional[Union[OrderedDict, Dict[str, OrderedDict]]]):
353+ Ordered dictionary of transformation passes.
354+ skip_node_id_set (Optional[set]):
355+ Set of node IDs to skip during partitioning.
356+ skip_node_op_set (Optional[set]):
357+ Set of node operations to skip during partitioning.
345358
346359 Returns:
347- EdgeProgramManager: The manager for the edge program after transformation and lowering.
360+ EdgeProgramManager:
361+ The manager for the edge program after transformation and lowering.
348362 """
-    ep = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes, strict=True)
-    # This transformation is primarily intended for the LiftConstantScalarOperands pass
-    # to avoid creating temporary tensors in the operation builder.
-    # However, this pass will create a get_attr node, which should be converted
-    # into a lifted tensor constant by the lift_constant_tensor_pass.
-    # If placed in the to_edge_transform_passes, it will be executed
-    # after the lift_constant_tensor_pass, causing the operation builder
-    # to fail to correctly retrieve the parameter by the get_parameter.
-    ep = QnnPassManager().transform_for_export_pipeline(ep)
-    transform_passes = QnnPassManager().get_to_edge_transform_passes(
-        ep, passes_job=passes_job, dep_table=dep_table
-    )
-    qnn_partitioner = QnnPartitioner(
-        compiler_specs,
-        skip_node_id_set=skip_node_id_set,
-        skip_node_op_set=skip_node_op_set,
-    )
-    edge_program_manager = to_edge_transform_and_lower(
-        ep,
+
+    def ensure_graph_specific_dict(value, graph_names, callback=None):
+        """
+        Ensures the input value is a dictionary with keys matching the provided graph names.
+        If the input is not a dictionary or its keys do not match the graph names, a new dictionary
+        is created with the graph names as keys and the input value assigned to each key.
+
+        Examples:
+            1. Input is None:
+                >>> ensure_graph_specific_dict(None, ["forward1", "forward2"])
+                {'forward1': None, 'forward2': None}
+
+            2. Input is a single value:
+                >>> ensure_graph_specific_dict(input, ["forward1", "forward2"])
+                {'forward1': input, 'forward2': input}
+
+            3. Input is a non-graph specific dict:
+                >>> ensure_graph_specific_dict({Any: input}, ["forward1", "forward2"])
+                {'forward1': {Any: input}, 'forward2': {Any: input}}
+        """
+        if value is None:
+            return {graph_name: None for graph_name in graph_names}
+        if isinstance(value, dict) and graph_names == value.keys():
+            return value
+        return {graph_name: value for graph_name in graph_names}
+
+    if not isinstance(module, dict):
+        module = {"forward": module}
+
+    # Ensure attributes are graph-specific dictionaries
+    graph_names = module.keys()
+    inputs = ensure_graph_specific_dict(inputs, graph_names)
+    compiler_specs = ensure_graph_specific_dict(compiler_specs, graph_names)
+    dynamic_shapes = ensure_graph_specific_dict(dynamic_shapes, graph_names)
+    dep_table = ensure_graph_specific_dict(dep_table, graph_names)
+    passes_job = ensure_graph_specific_dict(passes_job, graph_names)
+
+    # Prepare programs and partitioners
+    aten_programs = {}
+    transform_passes = {}
+    qnn_partitioners = {
+        graph_name: [
+            QnnPartitioner(
+                compiler_specs[graph_name],
+                skip_node_id_set=skip_node_id_set,
+                skip_node_op_set=skip_node_op_set,
+            )
+        ]
+        for graph_name in graph_names
+    }
+
+    for graph_name, m in module.items():
+        ep = torch.export.export(
+            m,
+            inputs[graph_name],
+            dynamic_shapes=dynamic_shapes[graph_name],
+            strict=True,
+        )
+        # This transformation is primarily intended for the LiftConstantScalarOperands pass
+        # to avoid creating temporary tensors in the operation builder.
+        # However, this pass will create a get_attr node, which should be converted
+        # into a lifted tensor constant by the lift_constant_tensor_pass.
+        # If placed in the to_edge_transform_passes, it will be executed
+        # after the lift_constant_tensor_pass, causing the operation builder
+        # to fail to correctly retrieve the parameter by the get_parameter.
+        aten_programs[graph_name] = QnnPassManager().transform_for_export_pipeline(ep)
+        transform_passes[graph_name] = QnnPassManager().get_to_edge_transform_passes(
+            ep, passes_job=passes_job[graph_name], dep_table=dep_table[graph_name]
+        )
+
+    return to_edge_transform_and_lower(
+        aten_programs,
         transform_passes=transform_passes,
-        partitioner=[qnn_partitioner],
+        partitioner=qnn_partitioners,
         constant_methods=constant_methods,
         compile_config=qnn_edge_config(),
     )
-    return edge_program_manager
 
 
 def capture_program(
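
For context, here is a hedged usage sketch (not part of the commit) of the widened signature: one call with a single module, as before, and one with a dict of named graphs, where `ensure_graph_specific_dict` broadcasts any non-dict argument (such as a single `compiler_specs` list) to every graph. The toy `Encode`/`Decode` modules and graph names are hypothetical, the spec helpers and import paths follow the ExecuTorch Qualcomm tutorial and may differ across versions, and actually running it requires a configured QNN SDK.

```python
# Hedged usage sketch -- not part of the diff above. Import paths, SoC model,
# and helper signatures are assumptions based on the ExecuTorch Qualcomm docs.
import torch
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    to_edge_transform_and_lower_to_qnn,
)


# Hypothetical toy graphs used only for illustration.
class Encode(torch.nn.Module):
    def forward(self, x):
        return x * 2


class Decode(torch.nn.Module):
    def forward(self, x):
        return x + 1


# A single List[CompileSpec] for Qualcomm AI Engine Direct (SoC model assumed).
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8550,
    backend_options=generate_htp_compiler_spec(use_fp16=True),
)

sample_input = (torch.randn(1, 8),)

# 1) Single-graph call: the previous calling convention still works; the module
#    is wrapped internally as {"forward": module}.
edge_mgr = to_edge_transform_and_lower_to_qnn(Encode(), sample_input, compiler_specs)

# 2) Multi-graph call: module and inputs are dicts keyed by graph name, while
#    the single compiler_specs list is broadcast to every graph (a per-graph
#    dict would also be accepted).
edge_mgr = to_edge_transform_and_lower_to_qnn(
    module={"encode": Encode(), "decode": Decode()},
    inputs={"encode": sample_input, "decode": sample_input},
    compiler_specs=compiler_specs,
)

# The returned EdgeProgramManager can then be serialized to an ExecuTorch program.
exec_prog = edge_mgr.to_executorch()
```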