     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False

         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid

     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,27 +143,42 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)

             return (True, gemm_deps)
         else:
             # Quantized Weight deps
             dequant_node = get_input_node(node, self.weight_idx)
             if not is_dequant(dequant_node):
+                why(node, "Expected weight to have a dequantized node")
                 return False, []
             gemm_deps.append(dequant_node)
             weight = get_input_node(dequant_node, 0)
             if not is_param_node(ep, weight):
+                why(node, "Expected weight to be a static param")
                 return False, []
             gemm_deps.append(weight)

+            if (
+                is_per_tensor(dequant_node)
+                and precision == ConfigPrecisionType.DYNAMIC_QUANT
+            ):
+                why(
+                    node,
+                    "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations",
+                )
+                return False, []
+
             if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                 if len(dequant_node.all_input_nodes) < 2:
                     # Expected channel quantized to have scale/zp nodes
+                    why(node, "Expected channel quantized to have scale/zp nodes")
                     return False, []

                 gemm_deps.extend(dequant_node.all_input_nodes[1:3])
+
             return (True, gemm_deps)

     def _get_output_deps(
@@ -174,7 +189,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []

             # Check if the quantized pattern has a fused activation
@@ -190,6 +205,7 @@ def _get_output_deps(

             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +235,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)

         return (True, gemm_deps)
@@ -233,7 +250,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +260,28 @@ def _get_act_deps(
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])

             gemm_deps.append(q_input)
             q_input_args = q_input.args
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])

             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]

             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])

             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
@@ -471,6 +489,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
             or len(src_partition.output_nodes) != 1
         ):
+            why(node, "invalid source partition")
             return (False, [])

         # map addmm's args to the source partition linear's inputs and users
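
Taken together, the change swaps silent `return (False, [])` bails for `why(node, reason)` calls, so every rejected GEMM candidate leaves a diagnostic trail. As a rough illustration only — the real `why` helper is imported from the XNNPACK partitioner's config module, and this logging-based sketch is an assumption, not the actual implementation:

import logging

import torch

logger = logging.getLogger(__name__)


def why(node: torch.fx.Node, reason: str) -> None:
    # Assumed behavior: record why a node was rejected for XNNPACK
    # delegation so partitioner decisions are debuggable, not silent.
    logger.debug("Skipping %s: %s", node.format_node(), reason)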