2121    is_dynamic_qdq ,
2222    is_per_channel ,
2323    is_per_channel_group ,
24+     is_per_tensor ,
2425    is_qparam ,
2526    is_quant ,
2627)
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
6667            return  False 
6768
6869        is_valid , _  =  self .get_deps (node , ep )
69-         if  not  is_valid :
70-             why (node , "Failed to get valid dependent nodes." )
7170        return  is_valid 
7271
7372    def  get_node_and_deps (
@@ -123,6 +122,7 @@ def get_deps(
123122        precision  =  self ._detect_precision (node )
124123        if  precision  not  in   self .supported_precision_types ():
125124            # detected precision but it is either disabled or not supported 
125+             why (node , f"Unsupported precision type { precision }  " )
126126            return  (False , [])
127127        _ , precision  =  self ._overwrite_precision (node )
128128        valid_bias , bias_deps  =  self ._get_bias_deps (node , ep , precision )
@@ -143,27 +143,42 @@ def _get_weight_deps(
143143            # First find the weight 
144144            weight_node  =  get_input_node (node , self .weight_idx )
145145            if  not  is_param_node (ep , weight_node ):
146-                 return  (False , [])  # weight must be a static param 
146+                 why (node , "Expected weight to be a static param" )
147+                 return  (False , [])
147148            gemm_deps .append (weight_node )
148149
149150            return  (True , gemm_deps )
150151        else :
151152            # Quantized Weight deps 
152153            dequant_node  =  get_input_node (node , self .weight_idx )
153154            if  not  is_dequant (dequant_node ):
155+                 why (node , "Expected  weight to have a dequantized node" )
154156                return  False , []
155157            gemm_deps .append (dequant_node )
156158            weight  =  get_input_node (dequant_node , 0 )
157159            if  not  is_param_node (ep , weight ):
160+                 why (node , "Expected weight to be a static param" )
158161                return  False , []
159162            gemm_deps .append (weight )
160163
164+             if  (
165+                 is_per_tensor (dequant_node )
166+                 and  precision  ==  ConfigPrecisionType .DYNAMIC_QUANT 
167+             ):
168+                 why (
169+                     node ,
170+                     "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations" ,
171+                 )
172+                 return  False , []
173+ 
161174            if  is_per_channel (dequant_node ) or  is_per_channel_group (dequant_node ):
162175                if  len (dequant_node .all_input_nodes ) <  2 :
163176                    # Expected channel quantized to have scale/zp nodes 
177+                     why (node , "Expected channel quantized to have scale/zp nodes" )
164178                    return  False , []
165179
166180                gemm_deps .extend (dequant_node .all_input_nodes [1 :3 ])
181+ 
167182            return  (True , gemm_deps )
168183
169184    def  _get_output_deps (
@@ -174,7 +189,7 @@ def _get_output_deps(
174189            # Look for fused activations and tail end quant node 
175190            node_users  =  list (node .users .keys ())
176191            if  len (node_users ) !=  1 :
177-                 # Expect  quantized node to have a single output (fused act or dequant )
192+                 why ( node ,  "Expected  quantized node to have a single output"  )
178193                return  False , []
179194
180195            # Check if the quantized pattern has a fused activation 
@@ -190,6 +205,7 @@ def _get_output_deps(
190205
191206            if  not  is_quant (n_output ):
192207                # Expected gemm_node --> fused_act (optional) --> dequant 
208+                 why (node , "Expected output node to have a dequantized node" )
193209                return  (False , [])
194210            gemm_deps .append (n_output )
195211        elif  precision  ==  ConfigPrecisionType .FP32 :
@@ -219,7 +235,8 @@ def _get_bias_deps(
219235            bias_node  =  get_input_node (node , self .bias_idx )
220236            if  bias_node :
221237                if  not  is_param_node (ep , bias_node ):
222-                     return  (False , [])  # bias node must be a static param 
238+                     why (node , "Expected bias to be a static param" )
239+                     return  (False , [])
223240                gemm_deps .append (bias_node )
224241
225242        return  (True , gemm_deps )
@@ -233,7 +250,7 @@ def _get_act_deps(
233250        else :
234251            dq_input  =  get_input_node (node , self .act_idx )
235252            if  not  is_dequant (dq_input ):
236-                 #  Expected static quant  input to be dequant node
253+                 why ( node ,  " Expected act  input to be dequant node" ) 
237254                return  False , []
238255            gemm_deps .append (dq_input )
239256            if  precision  ==  ConfigPrecisionType .STATIC_QUANT :
@@ -243,27 +260,28 @@ def _get_act_deps(
243260            # q input node 
244261            q_input  =  get_input_node (dq_input , 0 )
245262            if  not  is_quant (q_input ):
263+                 why (node , "Expected  dequant input to be quant node" )
246264                return  (False , [])
247265
248266            gemm_deps .append (q_input )
249267            q_input_args  =  q_input .args 
250268            if  is_affine_qdq (q_input ):
251269                q_input_args  =  extract_qdq_affine_op_args_for_decomposed_ops (q_input )
252270            if  not  (is_node (q_input_args [1 ]) and  is_node (q_input_args [2 ])):
253-                 #  expected to find getitem node from choose qparam
271+                 why ( node ,  " expected to find getitem node from choose qparam" ) 
254272                return  (False , [])
255273
256274            getitem1  =  q_input_args [1 ]
257275            getitem2  =  q_input_args [2 ]
258276
259277            if  not  (is_getitem (getitem1 ) and  is_getitem (getitem2 )):
260-                 #  expected getitem node from choose qparam
278+                 why ( node ,  " expected getitem node from choose qparam" ) 
261279                return  (False , [])
262280
263281            gemm_deps .extend ([getitem1 , getitem2 ])
264282            choose_qparam  =  get_input_node (getitem1 , 0 )
265283            if  not  is_qparam (choose_qparam ):
266-                 #  expected to find choose_qparam node
284+                 why ( node ,  " expected to find choose_qparam node" ) 
267285                return  (False , [])
268286            gemm_deps .append (choose_qparam )
269287            return  (True , gemm_deps )
@@ -471,6 +489,7 @@ def find_partition_args(input_node):
471489            # there can only be a single output node in partition 
472490            or  len (src_partition .output_nodes ) !=  1 
473491        ):
492+             why (node , "invalid source partition" )
474493            return  (False , [])
475494
476495        # map addmm's args to the source partition linear's inputs and users 
0 commit comments