|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 | import operator
|
7 | 7 | import warnings
|
8 |
| -from collections import OrderedDict |
| 8 | +from collections import defaultdict, OrderedDict |
9 | 9 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
10 | 10 |
|
11 | 11 | import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
|
@@ -1038,3 +1038,53 @@ def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
|
1038 | 1038 | for node in gm.graph.nodes:
|
1039 | 1039 | if dtype := get_quant_io_dtype_fn(node):
|
1040 | 1040 | node.meta[QCOM_QUANTIZED_IO] = dtype
|
| 1041 | + |
| 1042 | + |
def rewrite_prepared_observer(
    graph_module: torch.fx.GraphModule, name_obs_dict: Dict[str, torch.nn.Module]
):
    """
    Replace observers of a prepared graph_module, looked up by attribute name.

    Every submodule path that refers to the same module object as the named
    observer is rebound as well, so observers shared between several nodes
    stay shared after the rewrite.

    Args:
        graph_module: Module whose observers are replaced in place.
        name_obs_dict: Maps the attribute name of an existing observer on
            graph_module to the observer that should replace it.

    Example:
        Consider the following graph_module after prepare_pt2e:
            gm = prepare_pt2e(gm)
            print(gm)

            GraphModule(
              (activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf)
              (activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf)
              (activation_post_process_2): MinMaxObserver(min_val=inf, max_val=-inf)
              (activation_post_process_3): MinMaxObserver(min_val=inf, max_val=-inf)
            )

            new_observer = observer.FixedQParamsObserver(
                scale=0.125,
                zero_point=42,
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
            )

        Calling rewrite_prepared_observer(gm, {"activation_post_process_0": new_observer})
        is equivalent to:
            gm.activation_post_process_0 = new_observer

    Note:
        If the rewritten observer is a SharedQuantizationSpec, all other shared observers will also be rewritten.
        A name that does not exist on graph_module is skipped with a warning.
    """
    # Group every submodule path by the module object it points to.
    # remove_duplicate=False keeps all aliases of a shared observer.
    module_name_list = defaultdict(list)
    for name, module in graph_module.named_modules(remove_duplicate=False):
        module_name_list[module].append(name)

    for name, new_observer in name_obs_dict.items():
        old_module = getattr(graph_module, name, None)

        # Compare against None explicitly: modules defining __len__ (e.g. an
        # empty Sequential) are falsy and must not be treated as missing.
        if old_module is None:
            warnings.warn(
                f"No observer named as {name} found, please check the module name",
                stacklevel=2,
            )
            continue

        for target_name in module_name_list[old_module]:
            # An alias may be a dotted path into a nested submodule; rebind the
            # attribute on its direct parent rather than on the root module.
            parent_path, _, attr_name = target_name.rpartition(".")
            parent = graph_module.get_submodule(parent_path)
            setattr(parent, attr_name, new_observer)
0 commit comments