File tree Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -61,7 +61,7 @@ def __init__(
6161 self .buffer_limit = buffer_limit
6262 self .require_dynamic_shapes = require_dynamic_shape
6363
64- def op_node_is_compatible (
64+ def op_node_is_compatible ( # noqa: C901: Function is too complex
6565 self , node : torch .fx .Node , features : Optional [OpFeatures ] = None
6666 ) -> Tuple [bool , str ]:
6767 """
@@ -98,8 +98,12 @@ def op_node_is_compatible(
9898 and utils .is_tensor_node (arg )
9999 and i not in features .skip_limits_check
100100 ):
101+ # Check for bool inputs
102+ if utils .tensor_node_is_bool (arg ):
103+ return False , "contains bool tensor"
104+
101105 # Check for high dimensional tensors
102- if utils .is_tensor_node ( arg ) and utils . tensor_node_is_high_dim (arg ):
106+ if utils .tensor_node_is_high_dim (arg ):
103107 return False , "contains high dim tensor"
104108
105109 arg_texture_layouts = utils .possible_node_memory_layouts (
Original file line number Diff line number Diff line change @@ -80,6 +80,20 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
8080 return False
8181
8282
83+ def tensor_node_is_bool (node : torch .fx .Node ) -> bool :
84+ """
85+ Returns true if a given node contains a tensor with bool dtype
86+ """
87+ if isinstance (node .meta ["val" ], FakeTensor ):
88+ return node .meta ["val" ].dtype == torch .bool
89+ if isinstance (node .meta ["val" ], list ) or isinstance (node .meta ["val" ], tuple ):
90+ for fake_tensor in node .meta ["val" ]:
91+ if isinstance (fake_tensor , FakeTensor ):
92+ if fake_tensor .dtype == torch .bool :
93+ return True
94+ return False
95+
96+
8397##
8498## Memory Layout, Storage Type Determination
8599##
You can’t perform that action at this time.
0 commit comments