@@ -21,6 +21,7 @@
     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False

         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid

     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,27 +143,36 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)

             return (True, gemm_deps)
         else:
             # Quantized Weight deps
             dequant_node = get_input_node(node, self.weight_idx)
             if not is_dequant(dequant_node):
+                why(node, "Expected weight to have a dequantized node")
                 return False, []
             gemm_deps.append(dequant_node)
             weight = get_input_node(dequant_node, 0)
             if not is_param_node(ep, weight):
+                why(node, "Expected weight to be a static param")
                 return False, []
             gemm_deps.append(weight)

+            if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
+                why(node, "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations")
+                return False, []
+
             if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                 if len(dequant_node.all_input_nodes) < 2:
                     # Expected channel quantized to have scale/zp nodes
+                    why(node, "Expected channel quantized to have scale/zp nodes")
                     return False, []

                 gemm_deps.extend(dequant_node.all_input_nodes[1:3])
+
             return (True, gemm_deps)

     def _get_output_deps(
@@ -174,7 +183,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []

             # Check if the quantized pattern has a fused activation
@@ -190,6 +199,7 @@ def _get_output_deps(

             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +229,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)

         return (True, gemm_deps)
@@ -233,7 +244,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +254,28 @@ def _get_act_deps(
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])

             gemm_deps.append(q_input)
             q_input_args = q_input.args
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])

             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]

             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])

             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
@@ -471,6 +483,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
             or len(src_partition.output_nodes) != 1
         ):
+            why(node, "invalid source partition")
             return (False, [])

         # map addmm's args to the source partition linear's inputs and users
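
For context on the pattern this commit applies: every `_get_*_deps` helper returns a `(bool, List[torch.fx.Node])` pair, and each failing branch now calls `why(node, reason)` before returning, so the partitioner reports the specific constraint that rejected a gemm node instead of the old generic "Failed to get valid dependent nodes." message. Below is a minimal, self-contained sketch of that pattern; the logging-based `why` stand-in and the `_weight_deps_sketch` helper are illustrative assumptions, not ExecuTorch's actual implementations.

```python
# Minimal sketch of the "explain before rejecting" pattern in this commit.
# NOTE: this `why` is a logging-based stand-in; the real ExecuTorch helper
# and its import path are not shown in this diff.
import logging
from typing import List, Tuple

import torch

logger = logging.getLogger(__name__)


def why(node: torch.fx.Node, reason: str) -> None:
    # Record the specific reason a node is rejected so partitioner
    # decisions can be diagnosed instead of failing silently.
    logger.debug("Not partitioning %s: %s", node, reason)


def _weight_deps_sketch(node: torch.fx.Node) -> Tuple[bool, List[torch.fx.Node]]:
    # Hypothetical dependency check: every failing branch states its
    # reason via why(...) before returning (False, []).
    deps: List[torch.fx.Node] = []
    if len(node.all_input_nodes) < 2:
        why(node, "Expected node to have a weight input")
        return (False, [])
    deps.append(node.all_input_nodes[1])
    return (True, deps)
```

Keeping the failure reason next to each `return (False, [])` is what lets the top-level `check_constraints` drop its generic message in the first hunk above.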