Skip to content

Commit 8da72c3

Browse files
authored
[ET-VK][ez] Enable max_pool2d.default (#14044)
max_pool2d_with_indices is already implemented; this diff enables max_pool2d as well by just re-using the same implementation. Differential Revision: [D81513446](https://our.internmc.facebook.com/intern/diff/D81513446/)
1 parent d25f6c4 commit 8da72c3

File tree

4 files changed

+26
-7
lines changed

4 files changed

+26
-7
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
434434
@update_features(
435435
[
436436
exir_ops.edge.aten.avg_pool2d.default,
437+
exir_ops.edge.aten.max_pool2d.default,
437438
exir_ops.edge.aten.max_pool2d_with_indices.default,
438439
]
439440
)

backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ ${layout_declare_ubo(B, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "pad
2424

2525
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2626

27+
${layout_declare_spec_const(C, "int", "write_indices", "1")}
28+
2729
void main() {
2830
const ivec3 pos = ivec3(gl_GlobalInvocationID);
2931

@@ -55,5 +57,7 @@ void main() {
5557
}
5658

5759
imageStore(t_out, pos, out_texel);
58-
imageStore(t_idx, pos, idx_texel);
60+
if (write_indices > 0) {
61+
imageStore(t_idx, pos, idx_texel);
62+
}
5963
}

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,19 @@ void add_max_pool2d_node(
7676
const ValueRef dilation,
7777
const ValueRef ceil_mode,
7878
const ValueRef out) {
79-
const auto out_val = graph.get_value_list(out);
80-
const ValueRef out_tensor = out_val->at(0);
79+
ValueRef out_tensor = out;
80+
// Placeholder tensor to fill binding slot for indices tensor in case we are
81+
// computing max_pool2d instead of max_pool2d_with_indices.
82+
TmpTensor tmp_indices_tensor =
83+
TmpTensor(&graph, {}, graph.dtype_of(in), graph.storage_type_of(in));
84+
ValueRef indices_tensor = tmp_indices_tensor.vref;
85+
int32_t write_indices = 0;
86+
if (graph.val_is_value_list(out)) {
87+
const auto out_val = graph.get_value_list(out);
88+
out_tensor = out_val->at(0);
89+
indices_tensor = out_val->at(1);
90+
write_indices = 1;
91+
}
8192

8293
check_pool2d_args(graph, in, out_tensor);
8394

@@ -98,7 +109,7 @@ void add_max_pool2d_node(
98109
default_pick_global_wg_size,
99110
default_pick_local_wg_size,
100111
// Inputs and Outputs
101-
{{{out_val->at(0), out_val->at(1)}, vkapi::kWrite}, {in, vkapi::kRead}},
112+
{{{out_tensor, indices_tensor}, vkapi::kWrite}, {in, vkapi::kRead}},
102113
// Shader params buffers
103114
{
104115
graph.logical_limits_ubo(out_tensor),
@@ -108,7 +119,7 @@ void add_max_pool2d_node(
108119
// Push Constants
109120
{},
110121
// Specialization Constants
111-
{},
122+
{write_indices},
112123
// Resize Args
113124
{kernel_size, stride, padding, dilation, ceil_mode},
114125
// Resizing Logic
@@ -203,6 +214,7 @@ void avg_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
203214
REGISTER_OPERATORS {
204215
VK_REGISTER_OP(aten.avg_pool2d.default, avg_pool2d);
205216
VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
217+
VK_REGISTER_OP(aten.max_pool2d.default, max_pool2d);
206218
}
207219

208220
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,13 @@ def get_avg_pool2d_inputs():
270270
return test_suite
271271

272272

273-
@register_test_suite("aten.max_pool2d_with_indices.default")
273+
@register_test_suite(
274+
["aten.max_pool2d_with_indices.default", "aten.max_pool2d.default"]
275+
)
274276
def get_max_pool2d_inputs():
275277
test_suite = VkTestSuite(
276278
[
277-
((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
279+
((1, 7, 89, 77), [2, 2], [1, 1], [0, 0], [1, 1]),
278280
]
279281
)
280282
return test_suite

0 commit comments

Comments
 (0)