Skip to content

Commit 5118de4

Browse files
SS-JIAssjia
andauthored
[ET-VK] Add int and bool tensor support for many operators (#15834)
Original commit message: Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #15829 * #15796 * #15795 * #15794 * #15793 Title says it all! Adds `int32` and `uint8` shader variants to a bunch of operators that don't currently have variants for these dtypes, but should. This should prevent folks from running into dtype crashes at runtime when using the Vulkan delegate. Differential Revision: [D87082724](https://our.internmc.facebook.com/intern/diff/D87082724/) (cherry picked from commit a6c5921) ### Summary [PLEASE REMOVE] See [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests) for ExecuTorch PR guidelines. [PLEASE REMOVE] If this PR closes an issue, please add a `Fixes #<issue-id>` line. [PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: <area>" label. For a list of available release notes labels, check out [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests). ### Test plan [PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable. Co-authored-by: ssjia <[email protected]>
1 parent 51c99f6 commit 5118de4

File tree

12 files changed

+16
-4
lines changed

12 files changed

+16
-4
lines changed

backends/vulkan/runtime/graph/ops/glsl/clone.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ clone:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: int32
11+
- VALUE: uint8
1012
shader_variants:
1113
- NAME: clone

backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ concat_buffer:
66
DTYPE:
77
- VALUE: half
88
- VALUE: float
9+
- VALUE: int32
910
shader_variants:
1011
- NAME: concat_1_buffer
1112
NUM_INPUTS: 1

backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ void main() {
113113

114114
VEC4_T out_texel = imageLoad(t_out, out_pos);
115115

116-
VEC4_T test_texel = VEC4_T(-1.0);
117-
118116
for (int comp = 0; comp < 4; ++comp) {
119117
ivec4 out_tidx = out_read_start_tidx;
120118
out_tidx[out_packed_dim] += comp;
@@ -124,7 +122,6 @@ void main() {
124122
// of the previous input batch; if so, then don't overwrite this texel
125123
// element
126124
if (out_tidx[concat_dim] < concat_offset) {
127-
test_texel[comp] = -5.0;
128125
continue;
129126
}
130127

@@ -164,7 +161,6 @@ void main() {
164161
inp${i}_packed_dim);
165162

166163
out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w];
167-
test_texel[comp] = out_texel[comp];
168164
continue;
169165
}
170166
else {

backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ concat_texture:
66
DTYPE:
77
- VALUE: half
88
- VALUE: float
9+
- VALUE: int32
910
shader_variants:
1011
- NAME: concat_1_texture3d
1112
NUM_INPUTS: 1

backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ expand_buffer:
66
- VALUE: half
77
- VALUE: float
88
- VALUE: int32
9+
- VALUE: uint8
910
shader_variants:
1011
- NAME: expand_buffer

backends/vulkan/runtime/graph/ops/glsl/full.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ full:
1515
- VALUE: half
1616
- VALUE: float
1717
- VALUE: int32
18+
- VALUE: uint8
1819
shader_variants:
1920
- NAME: full

backends/vulkan/runtime/graph/ops/glsl/index_select.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ index_select:
88
- VALUE: half
99
- VALUE: float
1010
- VALUE: int32
11+
- VALUE: uint8
1112
shader_variants:
1213
- NAME: index_select

backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ index_select_channel:
88
- VALUE: half
99
- VALUE: float
1010
- VALUE: int32
11+
- VALUE: uint8
1112
shader_variants:
1213
- NAME: index_select_channel

backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ pad_channel:
88
DTYPE:
99
- VALUE: float
1010
- VALUE: half
11+
- VALUE: int32
12+
- VALUE: uint8
1113
shader_variants:
1214
- NAME: pad_channel

backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ pad_height_width:
88
DTYPE:
99
- VALUE: float
1010
- VALUE: half
11+
- VALUE: int32
12+
- VALUE: uint8
1113
shader_variants:
1214
- NAME: pad_height_width

0 commit comments

Comments
 (0)