Skip to content

Commit a566c10

Browse files
nathanaelseefacebook-github-bot
authored andcommitted
skip op in partitioning if there are bool input tensors
Summary: Vulkan backend does not support bool tensors Differential Revision: D69273733
1 parent b362ab7 commit a566c10

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def op_node_is_compatible(
9898
and utils.is_tensor_node(arg)
9999
and i not in features.skip_limits_check
100100
):
101+
# Check for bool inputs
102+
if utils.is_tensor_node(arg) and utils.tensor_node_is_bool(arg):
103+
return False, "contains bool tensor"
104+
101105
# Check for high dimensional tensors
102106
if utils.is_tensor_node(arg) and utils.tensor_node_is_high_dim(arg):
103107
return False, "contains high dim tensor"

backends/vulkan/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
8080
return False
8181

8282

83+
def tensor_node_is_bool(node: torch.fx.Node) -> bool:
84+
"""
85+
Returns true if a given node contains a tensor with bool dtype
86+
"""
87+
if isinstance(node.meta["val"], FakeTensor):
88+
return node.meta["val"].dtype == torch.bool
89+
if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
90+
for fake_tensor in node.meta["val"]:
91+
if isinstance(fake_tensor, FakeTensor):
92+
if fake_tensor.dtype == torch.bool:
93+
return True
94+
return False
95+
8396
##
8497
## Memory Layout, Storage Type Determination
8598
##

0 commit comments

Comments
 (0)