Skip to content

Commit a6c5921

Browse files
SS-JIAssjia
andauthored
[ET-VK] Add int and bool tensor support for many operators (pytorch#15829)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ pytorch#15829 * pytorch#15796 * pytorch#15795 * pytorch#15794 * pytorch#15793 Title says it all! Adds `int32` and `uint8` shader variants to a bunch of operators that don't currently have variants for these dtypes, but should. This should prevent folks from running into dtype crashes at runtime when using the Vulkan delegate. Differential Revision: [D87082724](https://our.internmc.facebook.com/intern/diff/D87082724/) Co-authored-by: ssjia <[email protected]>
1 parent 053193f commit a6c5921

16 files changed

+24
-4
lines changed

backends/vulkan/runtime/graph/ops/glsl/clone.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ clone:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: int32
11+
- VALUE: uint8
1012
shader_variants:
1113
- NAME: clone

backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ concat_buffer:
66
DTYPE:
77
- VALUE: half
88
- VALUE: float
9+
- VALUE: int32
910
shader_variants:
1011
- NAME: concat_1_buffer
1112
NUM_INPUTS: 1

backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ void main() {
113113

114114
VEC4_T out_texel = imageLoad(t_out, out_pos);
115115

116-
VEC4_T test_texel = VEC4_T(-1.0);
117-
118116
for (int comp = 0; comp < 4; ++comp) {
119117
ivec4 out_tidx = out_read_start_tidx;
120118
out_tidx[out_packed_dim] += comp;
@@ -124,7 +122,6 @@ void main() {
124122
// of the previous input batch; if so, then don't overwrite this texel
125123
// element
126124
if (out_tidx[concat_dim] < concat_offset) {
127-
test_texel[comp] = -5.0;
128125
continue;
129126
}
130127

@@ -164,7 +161,6 @@ void main() {
164161
inp${i}_packed_dim);
165162

166163
out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w];
167-
test_texel[comp] = out_texel[comp];
168164
continue;
169165
}
170166
else {

backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ concat_texture:
66
DTYPE:
77
- VALUE: half
88
- VALUE: float
9+
- VALUE: int32
910
shader_variants:
1011
- NAME: concat_1_texture3d
1112
NUM_INPUTS: 1

backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ expand_buffer:
66
- VALUE: half
77
- VALUE: float
88
- VALUE: int32
9+
- VALUE: uint8
910
shader_variants:
1011
- NAME: expand_buffer

backends/vulkan/runtime/graph/ops/glsl/full.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ full:
1515
- VALUE: half
1616
- VALUE: float
1717
- VALUE: int32
18+
- VALUE: uint8
1819
shader_variants:
1920
- NAME: full

backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,7 @@ gather_buffer:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15+
- VALUE: int32
16+
- VALUE: uint8
1517
shader_variants:
1618
- NAME: gather_buffer

backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,7 @@ gather_texture:
1111
DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: int32
15+
- VALUE: uint8
1416
shader_variants:
1517
- NAME: gather_texture3d

backends/vulkan/runtime/graph/ops/glsl/index_select.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ index_select:
88
- VALUE: half
99
- VALUE: float
1010
- VALUE: int32
11+
- VALUE: uint8
1112
shader_variants:
1213
- NAME: index_select

backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ index_select_channel:
88
- VALUE: half
99
- VALUE: float
1010
- VALUE: int32
11+
- VALUE: uint8
1112
shader_variants:
1213
- NAME: index_select_channel

0 commit comments

Comments
 (0)