| 
 | 1 | +# Copyright (c) Qualcomm Innovation Center, Inc.  | 
 | 2 | +# All rights reserved  | 
 | 3 | +#  | 
 | 4 | +# This source code is licensed under the BSD-style license found in the  | 
 | 5 | +# LICENSE file in the root directory of this source tree.  | 
 | 6 | + | 
 | 7 | +from typing import Dict, List  | 
 | 8 | + | 
 | 9 | +import torch  | 
 | 10 | +from executorch.backends.qualcomm.serialization.qc_schema import (  | 
 | 11 | +    QnnExecuTorchOpPackageInfo,  | 
 | 12 | +)  | 
 | 13 | + | 
 | 14 | +from .node_visitor import NodeVisitor  | 
 | 15 | +from .op_custom_op import CustomOp  | 
 | 16 | +from .utils import is_graph_input, is_graph_output  | 
 | 17 | + | 
 | 18 | + | 
 | 19 | +# This will hold mapping of all node names to the visitor class  | 
 | 20 | +_node_visitor_dict = {}  | 
 | 21 | + | 
 | 22 | + | 
 | 23 | +def register_node_visitor(visitor):  | 
 | 24 | +    """Register node visitor into _node_visitor_dict"""  | 
 | 25 | +    assert (  | 
 | 26 | +        isinstance(visitor, type)  | 
 | 27 | +        and issubclass(visitor, NodeVisitor)  | 
 | 28 | +        and hasattr(visitor, "target")  | 
 | 29 | +    ), f"Informed NodeVisitor subclass, can't register!, got: {visitor}"  | 
 | 30 | +    for target in visitor.target:  | 
 | 31 | +        _node_visitor_dict[target] = visitor  | 
 | 32 | + | 
 | 33 | + | 
 | 34 | +def generate_node_to_external_map(  | 
 | 35 | +    edge_program: torch.export.ExportedProgram,  | 
 | 36 | +) -> Dict[torch.fx.Node, int]:  | 
 | 37 | +    node_to_external_map = {}  | 
 | 38 | +    for node in edge_program.graph_module.graph.nodes:  | 
 | 39 | +        # The order in which we visit the placeholder node is same as the *args  | 
 | 40 | +        # order for the forward(*args) signature for this gm. Using the order of  | 
 | 41 | +        # the nodes as external_id to extract the right arg from *args at runtime  | 
 | 42 | +        if is_graph_input(node, edge_program):  | 
 | 43 | +            node_to_external_map[node] = len(node_to_external_map)  | 
 | 44 | +    for node in edge_program.graph_module.graph.nodes:  | 
 | 45 | +        if is_graph_output(node):  | 
 | 46 | +            node_to_external_map[node] = len(node_to_external_map)  | 
 | 47 | +    return node_to_external_map  | 
 | 48 | + | 
 | 49 | + | 
 | 50 | +def get_node_visitors(  | 
 | 51 | +    edge_program: torch.export.ExportedProgram,  | 
 | 52 | +    enable_tensor_dump=False,  | 
 | 53 | +    op_package_infos: List[QnnExecuTorchOpPackageInfo] = None,  | 
 | 54 | +) -> Dict[str, NodeVisitor]:  | 
 | 55 | +    """Create a new class instance at runtime, and put them in a dict"""  | 
 | 56 | +    node_to_external_map = generate_node_to_external_map(edge_program)  | 
 | 57 | +    node_visitors = {}  | 
 | 58 | +    for target, visitor in _node_visitor_dict.items():  | 
 | 59 | +        assert callable(  | 
 | 60 | +            visitor  | 
 | 61 | +        ), f"Expecting a callable class, but got {visitor} of type {type(visitor)}"  | 
 | 62 | +        node_visitors[target] = visitor(  | 
 | 63 | +            node_to_external_map, edge_program, enable_tensor_dump  | 
 | 64 | +        )  | 
 | 65 | +    if op_package_infos:  | 
 | 66 | +        custom_ops = []  | 
 | 67 | +        for op_package_info in op_package_infos:  | 
 | 68 | +            if op_package_info.custom_op_name not in custom_ops:  | 
 | 69 | +                custom_op_builder = CustomOp(  | 
 | 70 | +                    op_package_info,  | 
 | 71 | +                    node_to_external_map,  | 
 | 72 | +                    edge_program,  | 
 | 73 | +                    enable_tensor_dump,  | 
 | 74 | +                )  | 
 | 75 | +                node_visitors[op_package_info.custom_op_name] = custom_op_builder  | 
 | 76 | +                custom_ops.append(op_package_info.custom_op_name)  | 
 | 77 | +    return node_visitors  | 
0 commit comments