@@ -89,6 +89,41 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
8989 _annotate_output_qspec (node , quant_property .qspec )
9090
9191
92+ def _match_pattern (
93+ node : Node , pattern : List [List ], filter_fn : Optional [Callable [[Node ], bool ]] = None
94+ ) -> bool :
95+ """
96+ Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the
97+ chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the
98+ chain pass the filtering.
99+
100+ Each 'pattern' element is composed of a list of disjunctive nodes types.
101+ """
102+ assert len (pattern ) == 2 , "Only two-nodes patterns supported currently"
103+
104+ if node .target in pattern [0 ]:
105+ assert len (node .users ) != 0
106+ parent = node
107+ child = next (iter (node .users ))
108+ elif node .target in pattern [1 ]:
109+ assert len (node .args ) != 0
110+ parent = node .args [0 ]
111+ child = node
112+ else :
113+ return False
114+
115+ if len (parent .users ) != 1 :
116+ return False
117+
118+ if parent .target not in pattern [0 ] or child .target not in pattern [1 ]:
119+ return False
120+
121+ if filter_fn is not None :
122+ return filter_fn (parent ) and filter_fn (child )
123+
124+ return True
125+
126+
92127_one_to_one = [
93128 torch .ops .aten .exp .default ,
94129 torch .ops .aten .log .default ,
@@ -164,7 +199,36 @@ def get_quant_properties( # noqa: C901
164199 bias_qspec = quantization_config .get_bias_qspec ()
165200
166201 quant_properties = _OpQuantProperties ()
167- if node .target in (
202+
203+ def any_or_hardtanh_min_zero (n : Node ):
204+ # Check that if the node is a hardtanh, its min_val is zero
205+ return n .target != torch .ops .aten .hardtanh .default or n .args [1 ] == 0
206+
207+ if _match_pattern (
208+ node ,
209+ [
210+ [
211+ torch .ops .aten .conv1d .default ,
212+ torch .ops .aten .conv2d .default ,
213+ torch .ops .aten .linear .default ,
214+ ],
215+ [torch .ops .aten .relu .default , torch .ops .aten .hardtanh .default ],
216+ ],
217+ any_or_hardtanh_min_zero ,
218+ ):
219+ if node .target in (
220+ torch .ops .aten .conv1d .default ,
221+ torch .ops .aten .conv2d .default ,
222+ torch .ops .aten .linear .default ,
223+ ):
224+ quant_properties .quant_inputs = [
225+ _QuantProperty (0 , input_act_qspec ),
226+ _QuantProperty (1 , weight_qspec , mark_annotated = True ),
227+ _QuantProperty (2 , bias_qspec , optional = True , mark_annotated = True ),
228+ ]
229+ else :
230+ quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
231+ elif node .target in (
168232 torch .ops .aten .conv1d .default ,
169233 torch .ops .aten .conv2d .default ,
170234 torch .ops .aten .linear .default ,
0 commit comments