Skip to content

Commit 1fcbaf1

Browse files
billmguofacebook-github-bot
authored and committed
add custom annotation for new model export
Summary: add custom annotation for new model export Differential Revision: D77569950
1 parent 75e4044 commit 1fcbaf1

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,41 @@ def annotate_single_in_single_out(
227227
_annotated=True,
228228
)
229229

def annotate_single_in_share_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate a single-input op whose output reuses the input's qparams.

    The sole input gets the configured input-activation spec; the output
    observer is tied to that same (input, node) edge via
    SharedQuantizationSpec, so input and output quantize identically.
    """
    # The op's only data input is args[0].
    source = node.args[0]
    qspec_map = {source: quantization_config.input_activation}

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=qspec_map,
        # Share the output's quantization parameters with the input edge.
        output_qspec=SharedQuantizationSpec((source, node)),
        _annotated=True,
    )
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate aten.stack so all inputs and the output share one set of
    quantization parameters.

    The first stacked input carries the real input-activation spec; every
    other input and the output are bound to it through a
    SharedQuantizationSpec anchored on the (first_input, node) edge.
    """
    inputs = node.args[0]

    anchor = inputs[0]
    assert isinstance(anchor, Node)

    shared_spec = SharedQuantizationSpec((anchor, node))
    # The anchor input observes normally; the rest share its qparams.
    qspec_map = {anchor: quantization_config.input_activation}
    for other in inputs[1:]:
        # setdefault keeps the first assignment if an input repeats.
        qspec_map.setdefault(other, shared_spec)

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=qspec_map,
        output_qspec=shared_spec,
        _annotated=True,
    )
230265
def annotate_matmul_input1(node: Node):
231266
quantization_config_8a8w = get_8a8w_qnn_ptq_config(
232267
act_symmetric=True, act_observer=MinMaxObserver
@@ -247,6 +282,12 @@ def annotate_matmul_input1(node: Node):
247282
]:
248283
annotate_single_in_single_out(node, quantization_config_8a8w)
249284
node = node.args[0]
285+
elif node.target == torch.ops.aten.stack.default:
286+
annotate_stack(node, quantization_config_8a8w)
287+
node = node.args[0]
288+
elif node.target == torch.ops.aten.flatten.using_ints:
289+
annotate_single_in_share_out(node, quantization_config_8a8w)
290+
node = node.args[0]
250291
elif node.target == torch.ops.aten.cat.default:
251292
annotate_cat(node, quantization_config_8a8w)
252293
# For v, we tag 8a until conv op.

0 commit comments

Comments
 (0)