File tree Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -61,7 +61,7 @@ def __init__(
6161        self .buffer_limit  =  buffer_limit 
6262        self .require_dynamic_shapes  =  require_dynamic_shape 
6363
64-     def  op_node_is_compatible (
64+     def  op_node_is_compatible (   # noqa: C901: Function is too complex 
6565        self , node : torch .fx .Node , features : Optional [OpFeatures ] =  None 
6666    ) ->  Tuple [bool , str ]:
6767        """ 
@@ -98,8 +98,12 @@ def op_node_is_compatible(
9898                and  utils .is_tensor_node (arg )
9999                and  i  not  in   features .skip_limits_check 
100100            ):
101+                 # Check for bool inputs 
102+                 if  utils .tensor_node_is_bool (arg ):
103+                     return  False , "contains bool tensor" 
104+ 
101105                # Check for high dimensional tensors 
102-                 if  utils .is_tensor_node ( arg )  and   utils . tensor_node_is_high_dim (arg ):
106+                 if  utils .tensor_node_is_high_dim (arg ):
103107                    return  False , "contains high dim tensor" 
104108
105109                arg_texture_layouts  =  utils .possible_node_memory_layouts (
Original file line number Diff line number Diff line change @@ -80,6 +80,20 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
8080    return  False 
8181
8282
83+ def  tensor_node_is_bool (node : torch .fx .Node ) ->  bool :
84+     """ 
85+     Returns true if a given node contains a tensor with bool dtype 
86+     """ 
87+     if  isinstance (node .meta ["val" ], FakeTensor ):
88+         return  node .meta ["val" ].dtype  ==  torch .bool 
89+     if  isinstance (node .meta ["val" ], list ) or  isinstance (node .meta ["val" ], tuple ):
90+         for  fake_tensor  in  node .meta ["val" ]:
91+             if  isinstance (fake_tensor , FakeTensor ):
92+                 if  fake_tensor .dtype  ==  torch .bool :
93+                     return  True 
94+     return  False 
95+ 
96+ 
8397## 
8498## Memory Layout, Storage Type Determination 
8599## 
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments