     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False

         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid

     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,26 +143,34 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)

             return (True, gemm_deps)
         else:
             # Quantized Weight deps
             dequant_node = get_input_node(node, self.weight_idx)
             if not is_dequant(dequant_node):
+                why(node, "Expected weight to have a dequantized node")
                 return False, []
             gemm_deps.append(dequant_node)
             weight = get_input_node(dequant_node, 0)
             if not is_param_node(ep, weight):
+                why(node, "Expected weight to be a static param")
                 return False, []
             gemm_deps.append(weight)

             if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                 if len(dequant_node.all_input_nodes) < 2:
                     # Expected channel quantized to have scale/zp nodes
+                    why(node, "Expected channel quantized to have scale/zp nodes")
                     return False, []

+            if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
+                why(node, "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations")
+                return False, []
+
             gemm_deps.extend(dequant_node.all_input_nodes[1:3])
             return (True, gemm_deps)

@@ -174,7 +182,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []

             # Check if the quantized pattern has a fused activation
@@ -190,6 +198,7 @@ def _get_output_deps(

             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +228,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)

         return (True, gemm_deps)
@@ -233,7 +243,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +253,28 @@ def _get_act_deps(
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])

             gemm_deps.append(q_input)
             q_input_args = q_input.args
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])

             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]

             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])

             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
@@ -471,6 +482,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
             or len(src_partition.output_nodes) != 1
         ):
+            why(node, "invalid source partition")
             return (False, [])

         # map addmm's args to the source partition linear's inputs and users
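
For context: every hunk above follows the same pattern, replacing a silent `return (False, [])` with a `why(node, reason)` call before the early return, so a gemm node rejected by the partitioner leaves a human-readable trace. Below is a minimal sketch of that pattern; the `why` helper here is a stand-in written for illustration only, and the real helper shipped with the ExecuTorch XNNPACK partitioner may differ in name and signature.

import logging

import torch.fx

logger = logging.getLogger(__name__)


def why(node: torch.fx.Node, reason: str) -> None:
    # Stand-in for the partitioner's debug helper: record which node was
    # rejected and why, instead of silently returning (False, []).
    logger.debug("Not partitioning %s: %s", node.format_node(), reason)

With debug logging enabled (e.g. `logging.basicConfig(level=logging.DEBUG)`), a failed constraint check then surfaces as a message such as "Not partitioning ...: Expected weight to be a static param" rather than an unexplained unpartitioned node.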