Skip to content

Commit 82c0544

Browse files
author
ssjia
committed
[ET-VK][ez] Allow bool tensors to be lowered to ET-VK and add uint8(bool) dtype variants for several compute shaders
Pull Request resolved: #15155 Title says it all! ghstack-source-id: 317069711 @exported-using-ghexport Differential Revision: [D84716458](https://our.internmc.facebook.com/intern/diff/D84716458/)
1 parent 226300f commit 82c0544

File tree

7 files changed

+27
-4
lines changed

7 files changed

+27
-4
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
texture_limits: utils.ImageExtents,
6060
buffer_limit: int,
6161
require_dynamic_shape: bool = False,
62+
skip_bool_tensors: bool = False,
6263
operator_blocklist: Optional[Set[OpKey]] = None,
6364
operator_allowlist: Optional[Set[OpKey]] = None,
6465
fusable_subgraphs: Optional[List[PatternMatch]] = None,
@@ -69,6 +70,7 @@ def __init__(
6970
self.texture_limits: utils.ImageExtents = texture_limits
7071
self.buffer_limit = buffer_limit
7172
self.require_dynamic_shapes = require_dynamic_shape
73+
self.skip_bool_tensors = skip_bool_tensors
7274
self.operator_blocklist: Set[OpKey] = (
7375
operator_blocklist if operator_blocklist is not None else set()
7476
)
@@ -117,6 +119,11 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
117119
return False, "no operator implementation"
118120
features = get_op_features(target)
119121

122+
# bool tensors are internally represented with int8 buffers, which may not be
123+
# supported by some GPUs. Therefore, provide the option to skip these tensors.
124+
if self.skip_bool_tensors and utils.op_contains_bool_tensor(node):
125+
return False, f"op {utils.node_io_str(node)} contains bool tensor"
126+
120127
# Get the possible tensor representations for each tensor participating in the
121128
# this operator. Then check that all tensors are representable as either a
122129
# buffer or texture.
@@ -398,6 +405,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
398405
texture_limits,
399406
buffer_limit,
400407
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
408+
skip_bool_tensors=self.options.get("skip_bool_tensors", False),
401409
operator_blocklist=self.operator_blocklist,
402410
operator_allowlist=self.operator_allowlist,
403411
fusable_subgraphs=fusable_subgraphs,

backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ permute_buffer:
66
- VALUE: half
77
- VALUE: float
88
- VALUE: int32
9+
- VALUE: uint8
910
shader_variants:
1011
- NAME: permute_buffer

backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ permute_texture:
66
- VALUE: half
77
- VALUE: float
88
- VALUE: int32
9+
- VALUE: uint8
910
shader_variants:
1011
- NAME: permute_texture3d

backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ transfer_buffer:
88
- VALUE: half
99
- VALUE: float
1010
- VALUE: int32
11+
- VALUE: uint8
1112
shader_variants:
1213
- NAME: select_buffer
1314
OP_NAME: select

backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ transfer_texture:
88
- VALUE: half
99
- VALUE: float
1010
- VALUE: int32
11+
- VALUE: uint8
1112
shader_variants:
1213
- NAME: select_texture3d
1314
OP_NAME: select

backends/vulkan/runtime/graph/ops/glsl/view.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ view:
88
- VALUE: half
99
- VALUE: float
1010
- VALUE: int32
11+
- VALUE: uint8
1112
shader_variants:
1213
- NAME: view

backends/vulkan/utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,20 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
259259
return False
260260

261261

262+
def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
263+
"""
264+
Returns true if the operator used to compute the given node contains a bool tensor
265+
"""
266+
if is_tensor_node(node) and tensor_node_is_bool(node):
267+
return True
268+
269+
for arg_node in node.args:
270+
if is_tensor_node(arg_node) and tensor_node_is_bool(arg_node):
271+
return True
272+
273+
return False
274+
275+
262276
def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
263277
primary_arg_idx: Optional[int] = None
264278
for i, arg_node in enumerate(node.args):
@@ -693,10 +707,6 @@ def make_filtered_tensor_repset(
693707
if len(tensor_val.shape) > 4:
694708
return TensorRepSet(tensor_repset.valid_buffer_layouts, set())
695709

696-
# Bool tensors are currently not supported
697-
if tensor_val.dtype == torch.bool:
698-
return NO_STORAGE
699-
700710
return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)
701711

702712

0 commit comments

Comments
 (0)