2121    is_dynamic_qdq ,
2222    is_per_channel ,
2323    is_per_channel_group ,
24+     is_per_tensor ,
2425    is_qparam ,
2526    is_quant ,
2627)
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
6667            return  False 
6768
6869        is_valid , _  =  self .get_deps (node , ep )
69-         if  not  is_valid :
70-             why (node , "Failed to get valid dependent nodes." )
7170        return  is_valid 
7271
7372    def  get_node_and_deps (
@@ -123,6 +122,7 @@ def get_deps(
123122        precision  =  self ._detect_precision (node )
124123        if  precision  not  in   self .supported_precision_types ():
125124            # detected precision but it is either disabled or not supported 
125+             why (node , f"Unsupported precision type { precision }  " )
126126            return  (False , [])
127127        _ , precision  =  self ._overwrite_precision (node )
128128        valid_bias , bias_deps  =  self ._get_bias_deps (node , ep , precision )
@@ -143,26 +143,34 @@ def _get_weight_deps(
143143            # First find the weight 
144144            weight_node  =  get_input_node (node , self .weight_idx )
145145            if  not  is_param_node (ep , weight_node ):
146-                 return  (False , [])  # weight must be a static param 
146+                 why (node , "Expected weight to be a static param" )
147+                 return  (False , [])
147148            gemm_deps .append (weight_node )
148149
149150            return  (True , gemm_deps )
150151        else :
151152            # Quantized Weight deps 
152153            dequant_node  =  get_input_node (node , self .weight_idx )
153154            if  not  is_dequant (dequant_node ):
155+                 why (node , "Expected  weight to have a dequantized node" )
154156                return  False , []
155157            gemm_deps .append (dequant_node )
156158            weight  =  get_input_node (dequant_node , 0 )
157159            if  not  is_param_node (ep , weight ):
160+                 why (node , "Expected weight to be a static param" )
158161                return  False , []
159162            gemm_deps .append (weight )
160163
161164            if  is_per_channel (dequant_node ) or  is_per_channel_group (dequant_node ):
162165                if  len (dequant_node .all_input_nodes ) <  2 :
163166                    # Expected channel quantized to have scale/zp nodes 
167+                     why (node , "Expected channel quantized to have scale/zp nodes" )
164168                    return  False , []
165169
170+             if  is_per_tensor (dequant_node ) and  precision  ==  ConfigPrecisionType .DYNAMIC_QUANT :
171+                 why (node , "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations" )
172+                 return  False , []
173+ 
166174                gemm_deps .extend (dequant_node .all_input_nodes [1 :3 ])
167175            return  (True , gemm_deps )
168176
@@ -174,7 +182,7 @@ def _get_output_deps(
174182            # Look for fused activations and tail end quant node 
175183            node_users  =  list (node .users .keys ())
176184            if  len (node_users ) !=  1 :
177-                 # Expect  quantized node to have a single output (fused act or dequant )
185+                 why ( node ,  "Expected  quantized node to have a single output"  )
178186                return  False , []
179187
180188            # Check if the quantized pattern has a fused activation 
@@ -190,6 +198,7 @@ def _get_output_deps(
190198
191199            if  not  is_quant (n_output ):
192200                # Expected gemm_node --> fused_act (optional) --> dequant 
201+                 why (node , "Expected output node to have a dequantized node" )
193202                return  (False , [])
194203            gemm_deps .append (n_output )
195204        elif  precision  ==  ConfigPrecisionType .FP32 :
@@ -219,7 +228,8 @@ def _get_bias_deps(
219228            bias_node  =  get_input_node (node , self .bias_idx )
220229            if  bias_node :
221230                if  not  is_param_node (ep , bias_node ):
222-                     return  (False , [])  # bias node must be a static param 
231+                     why (node , "Expected bias to be a static param" )
232+                     return  (False , [])
223233                gemm_deps .append (bias_node )
224234
225235        return  (True , gemm_deps )
@@ -233,7 +243,7 @@ def _get_act_deps(
233243        else :
234244            dq_input  =  get_input_node (node , self .act_idx )
235245            if  not  is_dequant (dq_input ):
236-                 #  Expected static quant  input to be dequant node
246+                 why ( node ,  " Expected act  input to be dequant node" ) 
237247                return  False , []
238248            gemm_deps .append (dq_input )
239249            if  precision  ==  ConfigPrecisionType .STATIC_QUANT :
@@ -243,27 +253,28 @@ def _get_act_deps(
243253            # q input node 
244254            q_input  =  get_input_node (dq_input , 0 )
245255            if  not  is_quant (q_input ):
256+                 why (node , "Expected  dequant input to be quant node" )
246257                return  (False , [])
247258
248259            gemm_deps .append (q_input )
249260            q_input_args  =  q_input .args 
250261            if  is_affine_qdq (q_input ):
251262                q_input_args  =  extract_qdq_affine_op_args_for_decomposed_ops (q_input )
252263            if  not  (is_node (q_input_args [1 ]) and  is_node (q_input_args [2 ])):
253-                 #  expected to find getitem node from choose qparam
264+                 why ( node ,  " expected to find getitem node from choose qparam" ) 
254265                return  (False , [])
255266
256267            getitem1  =  q_input_args [1 ]
257268            getitem2  =  q_input_args [2 ]
258269
259270            if  not  (is_getitem (getitem1 ) and  is_getitem (getitem2 )):
260-                 #  expected getitem node from choose qparam
271+                 why ( node ,  " expected getitem node from choose qparam" ) 
261272                return  (False , [])
262273
263274            gemm_deps .extend ([getitem1 , getitem2 ])
264275            choose_qparam  =  get_input_node (getitem1 , 0 )
265276            if  not  is_qparam (choose_qparam ):
266-                 #  expected to find choose_qparam node
277+                 why ( node ,  " expected to find choose_qparam node" ) 
267278                return  (False , [])
268279            gemm_deps .append (choose_qparam )
269280            return  (True , gemm_deps )
@@ -471,6 +482,7 @@ def find_partition_args(input_node):
471482            # there can only be a single output node in partition 
472483            or  len (src_partition .output_nodes ) !=  1 
473484        ):
485+             why (node , "invalid source partition" )
474486            return  (False , [])
475487
476488        # map addmm's args to the source partition linear's inputs and users 
0 commit comments