Skip to content

Commit f55a8bf

Browse files
author
Nathanael See
committed
[BE][ET-VK] update max_pool2d to use layout gen
Pull Request resolved: #9591 TSIA @pytorchbot label "topic: not user facing" Differential Revision: [D71825476](https://our.internmc.facebook.com/intern/diff/D71825476/) ghstack-source-id: 274222178
1 parent bd12621 commit f55a8bf

File tree

2 files changed

+10
-21
lines changed

2 files changed

+10
-21
lines changed

backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,12 @@
1515

1616
layout(std430) buffer;
1717

18-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
19-
layout(set = 0, binding = 1, ${IMAGE_FORMAT["int"]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM]["int"]} image_idx;
20-
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in;
21-
22-
layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits {
23-
ivec3 out_limits;
24-
};
25-
26-
layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
27-
ivec4 in_sizes;
28-
};
29-
30-
layout(set = 0, binding = 5) uniform PRECISION restrict Params {
31-
ivec2 kernel_size;
32-
ivec2 stride;
33-
ivec2 padding;
34-
ivec2 dilation;
35-
};
18+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
19+
${layout_declare_tensor(B, "w", "t_idx", "int", STORAGE)}
20+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
21+
${layout_declare_ubo(B, "ivec3", "out_limits")}
22+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
23+
${layout_declare_ubo(B, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
3624

3725
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3826

@@ -54,7 +42,7 @@ void main() {
5442
for (int y = start.y; y < end.y; y += dilation.y) {
5543
for (int x = start.x; x < end.x; x += dilation.x) {
5644
if ((x >= 0 && x < in_sizes.x) && (y >= 0 && y < in_sizes.y)) {
57-
const vec4 cur_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
45+
const vec4 cur_texel = load_texel(t_in, ivec3(x, y, pos.z));
5846

5947
// Set idx if value is greatest in the pool; else, keep the existing idx.
6048
ivec4 cur_idx = ivec4(x + int(in_sizes.x) * y);
@@ -66,6 +54,6 @@ void main() {
6654
}
6755
}
6856

69-
imageStore(image_out, pos, out_texel);
70-
imageStore(image_idx, pos, idx_texel);
57+
imageStore(t_out, pos, out_texel);
58+
imageStore(t_idx, pos, idx_texel);
7159
}

backends/vulkan/runtime/graph/ops/glsl/max_pool2d.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ max_pool2d:
88
parameter_names_with_default_values:
99
NDIM: 3
1010
DTYPE: float
11+
STORAGE: texture3d
1112
generate_variant_forall:
1213
DTYPE:
1314
- VALUE: half

0 commit comments

Comments
 (0)