|  | 
| 5 | 5 | # LICENSE file in the root directory of this source tree. | 
| 6 | 6 | import operator | 
| 7 | 7 | import warnings | 
| 8 |  | -from collections import OrderedDict | 
|  | 8 | +from collections import defaultdict, OrderedDict | 
| 9 | 9 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union | 
| 10 | 10 |  | 
| 11 | 11 | import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor | 
| @@ -1038,3 +1038,53 @@ def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable): | 
| 1038 | 1038 |     for node in gm.graph.nodes: | 
| 1039 | 1039 |         if dtype := get_quant_io_dtype_fn(node): | 
| 1040 | 1040 |             node.meta[QCOM_QUANTIZED_IO] = dtype | 
|  | 1041 | + | 
|  | 1042 | + | 
|  | 1043 | +def rewrite_prepared_observer( | 
|  | 1044 | +    graph_module: torch.fx.GraphModule, name_obs_dict: Dict[str, torch.nn.Module] | 
|  | 1045 | +): | 
|  | 1046 | +    """ | 
|  | 1047 | +    Rewrite the observers registered in graph_module: each module name listed in name_obs_dict has its observer replaced by the corresponding new observer. | 
|  | 1048 | + | 
|  | 1049 | +    Example: | 
|  | 1050 | +    Consider the following graph_module after prepare_pt2e: | 
|  | 1051 | +    gm = prepare_pt2e(gm) | 
|  | 1052 | +    print(gm) | 
|  | 1053 | + | 
|  | 1054 | +    GraphModule( | 
|  | 1055 | +      (activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf) | 
|  | 1056 | +      (activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf) | 
|  | 1057 | +      (activation_post_process_2): MinMaxObserver(min_val=inf, max_val=-inf) | 
|  | 1058 | +      (activation_post_process_3): MinMaxObserver(min_val=inf, max_val=-inf) | 
|  | 1059 | +    ) | 
|  | 1060 | + | 
|  | 1061 | +    new_observer = observer.FixedQParamsObserver( | 
|  | 1062 | +        scale=0.125, | 
|  | 1063 | +        zero_point=42, | 
|  | 1064 | +        dtype=torch.uint8, | 
|  | 1065 | +        quant_min=0, | 
|  | 1066 | +        quant_max=255, | 
|  | 1067 | +        qscheme=torch.per_tensor_affine, | 
|  | 1068 | +    ) | 
|  | 1069 | + | 
|  | 1070 | +    Calling rewrite_prepared_observer(gm, {"activation_post_process_0": new_observer}) | 
|  | 1071 | +    is equivalent to: | 
|  | 1072 | +    gm.activation_post_process_0 = new_observer | 
|  | 1073 | + | 
|  | 1074 | +    Note: | 
|  | 1075 | +    If the observer being replaced is shared with other nodes (e.g., via SharedQuantizationSpec), every module name referring to that same observer instance is rewritten as well. | 
|  | 1076 | +    """ | 
|  | 1077 | +    module_name_list = defaultdict(list) | 
|  | 1078 | +    for name, module in graph_module.named_modules(remove_duplicate=False): | 
|  | 1079 | +        module_name_list[module].append(name) | 
|  | 1080 | + | 
|  | 1081 | +    for name, new_observer in name_obs_dict.items(): | 
|  | 1082 | +        old_module = getattr(graph_module, name, None) | 
|  | 1083 | + | 
|  | 1084 | +        if old_module is None: | 
|  | 1085 | +            warnings.warn( | 
|  | 1086 | +                f"No observer named {name} was found; please check the module name" | 
|  | 1087 | +            ) | 
|  | 1088 | +            continue | 
|  | 1089 | +        for target_name in module_name_list[old_module]: | 
|  | 1090 | +            setattr(graph_module, target_name, new_observer) | 
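
For readers trying this helper out, below is a minimal usage sketch of `rewrite_prepared_observer` in a PT2E quantization flow. The tiny model, input shapes, and the `activation_post_process_0` name are illustrative only (actual observer names depend on the prepared graph), and the import paths for `QnnQuantizer` and this helper are assumed to match the current layout of the Qualcomm backend; depending on the PyTorch version, the export entry point may be `torch.export.export_for_training` or `torch.export.export`.

```python
# Minimal usage sketch (illustrative model and observer name; assumes the
# standard torch.ao.quantization PT2E flow and this backend's QnnQuantizer).
import torch
from torch.ao.quantization import observer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from executorch.backends.qualcomm.utils.utils import rewrite_prepared_observer


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.nn.functional.relu(self.conv(x))


model = TinyModel().eval()
example_inputs = (torch.rand(1, 3, 16, 16),)

# Export the model and insert observers.
gm = torch.export.export_for_training(model, example_inputs).module()
gm = prepare_pt2e(gm, QnnQuantizer())

# Pin the quantization parameters of one observer before calibration.
new_observer = observer.FixedQParamsObserver(
    scale=0.125,
    zero_point=42,
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
)
rewrite_prepared_observer(gm, {"activation_post_process_0": new_observer})

# Calibrate, then convert to the quantized graph.
gm(*example_inputs)
quantized_gm = convert_pt2e(gm)
```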