|  | 
|  | 1 | +# Copyright (c) Qualcomm Innovation Center, Inc. | 
|  | 2 | +# All rights reserved | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +from typing import Dict, List | 
|  | 8 | + | 
|  | 9 | +import torch | 
|  | 10 | +from executorch.backends.qualcomm.serialization.qc_schema import ( | 
|  | 11 | +    QnnExecuTorchOpPackageInfo, | 
|  | 12 | +) | 
|  | 13 | + | 
|  | 14 | +from .node_visitor import NodeVisitor | 
|  | 15 | +from .op_custom_op import CustomOp | 
|  | 16 | +from .utils import is_graph_input, is_graph_output | 
|  | 17 | + | 
|  | 18 | + | 
|  | 19 | +# This will hold mapping of all node names to the visitor class | 
|  | 20 | +_node_visitor_dict = {} | 
|  | 21 | + | 
|  | 22 | + | 
|  | 23 | +def register_node_visitor(visitor): | 
|  | 24 | +    """Register node visitor into _node_visitor_dict""" | 
|  | 25 | +    assert ( | 
|  | 26 | +        isinstance(visitor, type) | 
|  | 27 | +        and issubclass(visitor, NodeVisitor) | 
|  | 28 | +        and hasattr(visitor, "target") | 
|  | 29 | +    ), f"Informed NodeVisitor subclass, can't register!, got: {visitor}" | 
|  | 30 | +    for target in visitor.target: | 
|  | 31 | +        _node_visitor_dict[target] = visitor | 
|  | 32 | + | 
|  | 33 | + | 
|  | 34 | +def generate_node_to_external_map( | 
|  | 35 | +    edge_program: torch.export.ExportedProgram, | 
|  | 36 | +) -> Dict[torch.fx.Node, int]: | 
|  | 37 | +    node_to_external_map = {} | 
|  | 38 | +    for node in edge_program.graph_module.graph.nodes: | 
|  | 39 | +        # The order in which we visit the placeholder node is same as the *args | 
|  | 40 | +        # order for the forward(*args) signature for this gm. Using the order of | 
|  | 41 | +        # the nodes as external_id to extract the right arg from *args at runtime | 
|  | 42 | +        if is_graph_input(node, edge_program): | 
|  | 43 | +            node_to_external_map[node] = len(node_to_external_map) | 
|  | 44 | +    for node in edge_program.graph_module.graph.nodes: | 
|  | 45 | +        if is_graph_output(node): | 
|  | 46 | +            node_to_external_map[node] = len(node_to_external_map) | 
|  | 47 | +    return node_to_external_map | 
|  | 48 | + | 
|  | 49 | + | 
|  | 50 | +def get_node_visitors( | 
|  | 51 | +    edge_program: torch.export.ExportedProgram, | 
|  | 52 | +    enable_tensor_dump=False, | 
|  | 53 | +    op_package_infos: List[QnnExecuTorchOpPackageInfo] = None, | 
|  | 54 | +) -> Dict[str, NodeVisitor]: | 
|  | 55 | +    """Create a new class instance at runtime, and put them in a dict""" | 
|  | 56 | +    node_to_external_map = generate_node_to_external_map(edge_program) | 
|  | 57 | +    node_visitors = {} | 
|  | 58 | +    for target, visitor in _node_visitor_dict.items(): | 
|  | 59 | +        assert callable( | 
|  | 60 | +            visitor | 
|  | 61 | +        ), f"Expecting a callable class, but got {visitor} of type {type(visitor)}" | 
|  | 62 | +        node_visitors[target] = visitor( | 
|  | 63 | +            node_to_external_map, edge_program, enable_tensor_dump | 
|  | 64 | +        ) | 
|  | 65 | +    if op_package_infos: | 
|  | 66 | +        custom_ops = [] | 
|  | 67 | +        for op_package_info in op_package_infos: | 
|  | 68 | +            if op_package_info.custom_op_name not in custom_ops: | 
|  | 69 | +                custom_op_builder = CustomOp( | 
|  | 70 | +                    op_package_info, | 
|  | 71 | +                    node_to_external_map, | 
|  | 72 | +                    edge_program, | 
|  | 73 | +                    enable_tensor_dump, | 
|  | 74 | +                ) | 
|  | 75 | +                node_visitors[op_package_info.custom_op_name] = custom_op_builder | 
|  | 76 | +                custom_ops.append(op_package_info.custom_op_name) | 
|  | 77 | +    return node_visitors | 
0 commit comments