@@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]:
6767 @abstractmethod
6868 def get_anchors (
6969 self , gm : torch .fx .GraphModule , fused_partition : List [fx .GraphModule ]
70- ) -> Optional [PartitionAnchors ]:
70+ ) -> Tuple [PartitionAnchors , fx . Node ]:
7171 pass
7272
7373 @abstractmethod
@@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]:
8585
8686 def get_anchors (
8787 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
88- ) -> PartitionAnchors :
88+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
8989 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
9090 addmm_node = fused_partition [0 ].nodes [- 1 ]
9191
@@ -101,12 +101,12 @@ def get_anchors(
101101 qscheme = torch .per_tensor_affine ,
102102 )
103103
104- return PartitionAnchors (
104+ return ( PartitionAnchors (
105105 inputs = [(addmm_node , 1 )],
106106 weights = [(addmm_node , 2 )],
107107 biases = [(addmm_node , 0 , bias_qspec )],
108108 output = [(addmm_node ,)],
109- )
109+ ), addmm_node )
110110
111111 def replacement_op (self ) -> OpOverload :
112112 return torch .ops .cadence .quantized_linear .default
@@ -118,7 +118,7 @@ def partition_types(self) -> List[OpOverload]:
118118
119119 def get_anchors (
120120 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
121- ) -> PartitionAnchors :
121+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
122122 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
123123 add_node = fused_partition [0 ].nodes [- 1 ]
124124
@@ -129,16 +129,16 @@ def get_anchors(
129129 add_node .args [1 ], fx .Node
130130 )
131131 if not is_tensor_add or len (add_node .kwargs ) > 0 :
132- return PartitionAnchors (
132+ return ( PartitionAnchors (
133133 empty = True ,
134- )
134+ ), add_node )
135135
136- return PartitionAnchors (
136+ return ( PartitionAnchors (
137137 inputs = [(add_node , 0 ), (add_node , 1 )],
138138 weights = [],
139139 biases = [],
140140 output = [(add_node ,)],
141- )
141+ ), add_node )
142142
143143 def replacement_op (self ) -> OpOverload :
144144 return torch .ops .cadence .quantized_add .default
@@ -150,16 +150,16 @@ def partition_types(self) -> List[OpOverload]:
150150
151151 def get_anchors (
152152 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
153- ) -> PartitionAnchors :
153+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
154154 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
155155 bmm_node = fused_partition [0 ].nodes [- 1 ]
156156
157- return PartitionAnchors (
157+ return ( PartitionAnchors (
158158 inputs = [(bmm_node , 0 ), (bmm_node , 1 )],
159159 weights = [],
160160 biases = [],
161161 output = [(bmm_node ,)],
162- )
162+ ), bmm_node )
163163
164164 def replacement_op (self ) -> OpOverload :
165165 return torch .ops .cadence .quantized_matmul .default
@@ -171,7 +171,7 @@ def partition_types(self) -> List[OpOverload]:
171171
172172 def get_anchors (
173173 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
174- ) -> PartitionAnchors :
174+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
175175 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
176176 cat_node = fused_partition [0 ].nodes [- 1 ]
177177
@@ -198,14 +198,14 @@ def get_anchors(
198198 )
199199 )
200200
201- return PartitionAnchors (
201+ return ( PartitionAnchors (
202202 inputs = args ,
203203 weights = [],
204204 biases = [],
205205 output = [
206206 (cat_node , SharedQuantizationSpec ((cat_node .args [0 ][0 ], cat_node )))
207207 ],
208- )
208+ ), cat_node )
209209
210210 def replacement_op (self ) -> OpOverload :
211211 return torch .ops .aten .cat .default
@@ -217,7 +217,7 @@ def partition_types(self) -> List[OpOverload]:
217217
218218 def get_anchors (
219219 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
220- ) -> PartitionAnchors :
220+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
221221 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
222222 conv1d_node = fused_partition [0 ].nodes [- 1 ]
223223
@@ -238,13 +238,13 @@ def get_anchors(
238238 if len (conv1d_node .args ) > 2 and conv1d_node .args [2 ] is not None :
239239 bias = [(conv1d_node , 2 , bias_qspec )]
240240
241- return PartitionAnchors (
241+ return ( PartitionAnchors (
242242 inputs = [(conv1d_node , 0 )],
243243 weights = [(conv1d_node , 1 )],
244244 # pyre-fixme[6]: Incompatible parameter type
245245 biases = bias ,
246246 output = [(conv1d_node ,)],
247- )
247+ ), conv1d_node )
248248
249249 def replacement_op (self ) -> OpOverload :
250250 return torch .ops .cadence .quantized_conv2d_nchw .default
@@ -256,7 +256,7 @@ def partition_types(self) -> List[OpOverload]:
256256
257257 def get_anchors (
258258 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
259- ) -> PartitionAnchors :
259+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
260260 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
261261 conv2d_node = fused_partition [0 ].nodes [- 1 ]
262262
@@ -277,13 +277,13 @@ def get_anchors(
277277 if len (conv2d_node .args ) > 2 and conv2d_node .args [2 ] is not None :
278278 bias = [(conv2d_node , 2 , bias_qspec )]
279279
280- return PartitionAnchors (
280+ return ( PartitionAnchors (
281281 inputs = [(conv2d_node , 0 )],
282282 weights = [(conv2d_node , 1 )],
283283 # pyre-fixme[6]: Incompatible parameter type
284284 biases = bias ,
285285 output = [(conv2d_node ,)],
286- )
286+ ), conv2d_node )
287287
288288 def replacement_op (self ) -> OpOverload :
289289 return torch .ops .cadence .quantized_conv2d_nchw .default
@@ -295,7 +295,7 @@ def partition_types(self) -> List[OpOverload]:
295295
296296 def get_anchors (
297297 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
298- ) -> PartitionAnchors :
298+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
299299 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
300300 layer_norm_node = fused_partition [0 ].nodes [- 1 ]
301301
@@ -311,14 +311,14 @@ def get_anchors(
311311
312312 # Weights are used in quantized mode by our kernel, so they are
313313 # passed in as others here along with the normalized shape.
314- return PartitionAnchors (
314+ return ( PartitionAnchors (
315315 inputs = [(layer_norm_node , 0 )],
316316 weights = [],
317317 biases = [],
318318 # Ordering: normalized_shape, weights, bias
319319 others = others ,
320320 output = [(layer_norm_node ,)],
321- )
321+ ), layer_norm_node )
322322
323323 def replacement_op (self ) -> OpOverload :
324324 return torch .ops .cadence .quantized_layer_norm .default
@@ -330,7 +330,7 @@ def partition_types(self) -> List[OpOverload]:
330330
331331 def get_anchors (
332332 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
333- ) -> PartitionAnchors :
333+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
334334 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
335335 linear_node = fused_partition [0 ].nodes [- 1 ]
336336
@@ -351,13 +351,13 @@ def get_anchors(
351351 if len (linear_node .args ) > 2 :
352352 bias = [(linear_node , 2 , bias_qspec )]
353353
354- return PartitionAnchors (
354+ return ( PartitionAnchors (
355355 inputs = [(linear_node , 0 )],
356356 weights = [(linear_node , 1 )],
357357 # pyre-fixme[6]: Incompatible parameter type
358358 biases = bias ,
359359 output = [(linear_node ,)],
360- )
360+ ), linear_node )
361361
362362 def replacement_op (self ) -> OpOverload :
363363 return torch .ops .cadence .quantized_linear .default
@@ -369,16 +369,16 @@ def partition_types(self) -> List[OpOverload]:
369369
370370 def get_anchors (
371371 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
372- ) -> PartitionAnchors :
372+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
373373 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
374374 matmul_node = fused_partition [0 ].nodes [- 1 ]
375375
376- return PartitionAnchors (
376+ return ( PartitionAnchors (
377377 inputs = [(matmul_node , 0 ), (matmul_node , 1 )],
378378 weights = [],
379379 biases = [],
380380 output = [(matmul_node ,)],
381- )
381+ ), matmul_node )
382382
383383 def replacement_op (self ) -> OpOverload :
384384 return torch .ops .cadence .quantized_matmul .default
@@ -392,16 +392,16 @@ def partition_types(self) -> List[OpOverload]:
392392
393393 def get_anchors (
394394 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
395- ) -> PartitionAnchors :
395+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
396396 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
397397 relu_node = fused_partition [0 ].nodes [- 1 ]
398398
399- return PartitionAnchors (
399+ return ( PartitionAnchors (
400400 inputs = [(relu_node , 0 )],
401401 weights = [],
402402 biases = [],
403403 output = [(relu_node ,)],
404- )
404+ ), relu_node )
405405
406406 def replacement_op (self ) -> OpOverload :
407407 return torch .ops .cadence .quantized_relu .default
@@ -427,7 +427,7 @@ def partition_types(self) -> List[OpOverload]:
427427
428428 def get_anchors (
429429 self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
430- ) -> PartitionAnchors :
430+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
431431 # The first node should be conv, the second should be relu
432432 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
433433 conv_node = fused_partition [0 ].nodes [- 1 ] # Second to last node
@@ -451,13 +451,13 @@ def get_anchors(
451451 if len (conv_node .args ) > 2 and conv_node .args [2 ] is not None :
452452 bias = [(conv_node , 2 , bias_qspec )]
453453
454- return PartitionAnchors (
454+ return ( PartitionAnchors (
455455 inputs = [(conv_node , 0 )],
456456 weights = [(conv_node , 1 )],
457457 # pyre-fixme[6]: Incompatible parameter type
458458 biases = bias ,
459459 output = [(relu_node ,)], # Output is from the relu node
460- )
460+ ), relu_node )
461461
462462 def replacement_op (self ) -> OpOverload :
463463 return torch .ops .cadence .quantized_conv2d_nchw .default
@@ -498,12 +498,12 @@ def get_anchors(
498498 # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
499499 softmax_node = fused_partition [0 ].nodes [- 1 ]
500500
501- return PartitionAnchors (
501+ return ( PartitionAnchors (
502502 inputs = [(softmax_node , 0 )],
503503 weights = [],
504504 biases = [],
505505 output = [(softmax_node ,)],
506- )
506+ ), softmax_node )
507507
508508 def replacement_op (self ) -> OpOverload :
509509 return torch .ops .cadence .quantized_softmax .default
0 commit comments