Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
texture_limits: utils.ImageExtents,
buffer_limit: int,
require_dynamic_shape: bool = False,
skip_bool_tensors: bool = False,
operator_blocklist: Optional[Set[OpKey]] = None,
operator_allowlist: Optional[Set[OpKey]] = None,
fusable_subgraphs: Optional[List[PatternMatch]] = None,
Expand All @@ -69,6 +70,7 @@ def __init__(
self.texture_limits: utils.ImageExtents = texture_limits
self.buffer_limit = buffer_limit
self.require_dynamic_shapes = require_dynamic_shape
self.skip_bool_tensors = skip_bool_tensors
self.operator_blocklist: Set[OpKey] = (
operator_blocklist if operator_blocklist is not None else set()
)
Expand Down Expand Up @@ -117,6 +119,11 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
return False, "no operator implementation"
features = get_op_features(target)

# bool tensors are internally represented with int8 buffers, which may not be
# supported by some GPUs. Therefore, provide the option to skip these tensors.
if self.skip_bool_tensors and utils.op_contains_bool_tensor(node):
return False, f"op {utils.node_io_str(node)} contains bool tensor"

# Get the possible tensor representations for each tensor participating in
# this operator. Then check that all tensors are representable as either a
# buffer or texture.
Expand Down Expand Up @@ -398,6 +405,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
texture_limits,
buffer_limit,
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
skip_bool_tensors=self.options.get("skip_bool_tensors", False),
operator_blocklist=self.operator_blocklist,
operator_allowlist=self.operator_allowlist,
fusable_subgraphs=fusable_subgraphs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ permute_buffer:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: permute_buffer
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ permute_texture:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: permute_texture3d
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ transfer_buffer:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: select_buffer
OP_NAME: select
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ transfer_texture:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: select_texture3d
OP_NAME: select
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ view:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: view
18 changes: 14 additions & 4 deletions backends/vulkan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,20 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
return False


def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
    """
    Report whether a bool tensor participates in the operator producing `node`.

    The node itself is examined first, then each entry of `node.args` in order.
    An entry counts when both `is_tensor_node` and `tensor_node_is_bool` hold
    for it. Only positional args are inspected; `node.kwargs` is not consulted.

    Returns True on the first match, and False when nothing matches.
    """

    def _is_bool_tensor(candidate) -> bool:
        # Same short-circuit order as a plain `and`: the bool check only runs
        # on candidates that already passed the tensor-node check.
        return is_tensor_node(candidate) and tensor_node_is_bool(candidate)

    # Check the output first so the args are not touched when it already matches.
    if _is_bool_tensor(node):
        return True

    return any(_is_bool_tensor(operand) for operand in node.args)


def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
primary_arg_idx: Optional[int] = None
for i, arg_node in enumerate(node.args):
Expand Down Expand Up @@ -693,10 +707,6 @@ def make_filtered_tensor_repset(
if len(tensor_val.shape) > 4:
return TensorRepSet(tensor_repset.valid_buffer_layouts, set())

# Bool tensors are currently not supported
if tensor_val.dtype == torch.bool:
return NO_STORAGE

return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)


Expand Down
Loading