Skip to content

Commit 2322c03

Browse files
author
ssjia
committed
[ET-VK][ez] Allow bool tensors to be lowered to ET-VK and add uint8(bool) dtype variants for several compute shaders
Title says it all! Differential Revision: [D84716458](https://our.internmc.facebook.com/intern/diff/D84716458/) [ghstack-poisoned]
1 parent 535e916 commit 2322c03

File tree

3 files changed

+2
-4
lines changed

3 files changed

+2
-4
lines changed

backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ permute_buffer:
66
- VALUE: half
77
- VALUE: float
88
- VALUE: int32
9+
- VALUE: uint8
910
shader_variants:
1011
- NAME: permute_buffer

backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ permute_texture:
66
- VALUE: half
77
- VALUE: float
88
- VALUE: int32
9+
- VALUE: uint8
910
shader_variants:
1011
- NAME: permute_texture3d

backends/vulkan/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -693,10 +693,6 @@ def make_filtered_tensor_repset(
693693
if len(tensor_val.shape) > 4:
694694
return TensorRepSet(tensor_repset.valid_buffer_layouts, set())
695695

696-
# Bool tensors are currently not supported
697-
if tensor_val.dtype == torch.bool:
698-
return NO_STORAGE
699-
700696
return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)
701697

702698

0 commit comments

Comments
 (0)