
Commit f78bf36

Author: ssjia (committed)
[ET-VK] Introduce specialized implementation for per-row reduction
Pull Request resolved: #15161

Title says it all! This diff also adds support for argmin and argmax, but only for per-row reduction.

ghstack-source-id: 317645474
@exported-using-ghexport
Differential Revision: [D84716454](https://our.internmc.facebook.com/intern/diff/D84716454/)
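For context, a per-row reduction is one where only the innermost dimension is reduced, producing one value per row. A minimal sketch (illustrative only, not taken from this commit's tests) of the eager-mode patterns the specialized path targets:

import torch

x = torch.randn(4, 128)  # 4 rows of 128 elements each

row_sum = torch.sum(x, dim=-1, keepdim=True)   # shape [4, 1]
row_max = torch.amax(x, dim=-1, keepdim=True)  # shape [4, 1]
row_arg = torch.argmax(x, dim=-1)              # shape [4]; argmax/argmin are only
                                               # supported via the per-row path

Reductions over other dims, or over multiple dims, continue to go through the general implementation.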
1 parent d9463bb commit f78bf36

File tree: 9 files changed, +576 -58 lines changed


backends/vulkan/op_registry.py

Lines changed: 116 additions & 57 deletions
@@ -420,75 +420,133 @@ def register_softmax_op():
     )
 
 
-@update_features(
-    [
-        exir_ops.edge.aten.mean.dim,
-        exir_ops.edge.aten.sum.dim_IntList,
-        exir_ops.edge.aten.amax.default,
-        exir_ops.edge.aten.amin.default,
-    ]
-)
-def register_reduce_op():
-    def check_reduce_node(node: torch.fx.Node) -> bool:
-        # Only one argument implies that the reduction is over the entire tensor, which
-        # is not supported yet.
-        if len(node.args) == 1:
-            return False
+def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
+    ndim = utils.ndim_of(node.args[0])
+    assert ndim is not None
+    dims_reduced = None
+    if len(node.args) >= 1:
+        dims_reduced = node.args[1]
 
-        dim_list = node.args[1]
-        # Only 1D and 2D reductions are supported at the moment.
-        if isinstance(dim_list, list) and len(dim_list) > 2:
-            return False
+    # If dim_list is None, return a list containing all the dims of the tensor
+    if dims_reduced is None:
+        dims_reduced = list(range(ndim))
 
-        def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
-            for arg in node.args:
-                if isinstance(arg, bool):
-                    return arg
+    # Special case for reducing tensors with shape [1, N] - this is equivalent to
+    # reducing the last dim.
+    if utils.is_unsqueezed_vector(node) and ndim == 2:
+        dims_reduced = 1
 
-            # Assume false by default
-            return False
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
+        dims_reduced = dims_reduced[0]
 
-        keepdim = try_find_keepdim_arg(node)
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    assert isinstance(dims_reduced, (int, list, tuple))
+    return utils.normalize_dims(dims_reduced, ndim)
 
-        return True
 
-    def pick_io_storage_for_reduce(node: torch.fx.Node):
-        inputs_storage = utils.ANY_TEXTURE
-        outputs_storage = utils.ANY_TEXTURE
-
-        input_tensor = node.args[0]
-        ndim = input_tensor.meta["val"].ndim
-        dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
-            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
-
-            possible_packed_dims = {0, 1, 2}
-            possible_packed_dims.discard(reduce_dim1_whcn)
-            possible_packed_dims.discard(reduce_dim2_whcn)
-
-            packed_dim = possible_packed_dims.pop()
-            assert packed_dim in [0, 1, 2]
-
-            if packed_dim == 0:
-                inputs_storage = utils.WIDTH_PACKED_TEXTURE
-                outputs_storage = utils.WIDTH_PACKED_TEXTURE
-            elif packed_dim == 1:
-                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
-                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
-            else:
-                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
-                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+def get_keepdim_setting(node: torch.fx.Node) -> bool:
+    for arg in node.args:
+        if isinstance(arg, bool):
+            return arg
+
+    # Assume false by default
+    return False
+
+
+def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
+    """
+    Checks if a reduction node is supported by the Vulkan backend's reduce per row
+    special case implementation.
+    """
+    input_ndim = utils.ndim_of(node.args[0])
+    assert input_ndim is not None
+    dims_reduced = get_dims_reduced(node)
+
+    return dims_reduced == input_ndim - 1
+
+
+def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool:
+    dims_reduced = get_dims_reduced(node)
+    # Only 1D and 2D reductions are supported at the moment.
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) > 2:
+        return False
+
+    keepdim = get_keepdim_setting(node)
+    # keepdim = False is not supported yet for general implementation
+    if isinstance(keepdim, bool) and not keepdim:
+        return False
+
+    return True
+
+
+def is_reduce_node_supported(node: torch.fx.Node) -> bool:
+    # 0-dim output unsupported at the moment
+    if utils.ndim_of(node) == 0:
+        return False
+
+    return is_reduce_node_supported_by_per_row_impl(
+        node
+    ) or is_reduce_node_supported_by_general_impl(node)
 
+
+def pick_storage_for_reduce(node: torch.fx.Node):
+    inputs_storage = utils.NO_STORAGE
+    outputs_storage = utils.NO_STORAGE
+
+    ndim = utils.ndim_of(node.args[0])
+    dim_list = node.args[1]
+
+    if is_reduce_node_supported_by_general_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE)
+        outputs_storage = inputs_storage
+
+    # For 1D reductions of the last dim, a special reduce per row case is implemented
+    # for buffer backed tensors.
+    if is_reduce_node_supported_by_per_row_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER)
+        outputs_storage = inputs_storage
         return inputs_storage, outputs_storage
 
+    # For 2D reductions, the packed dimension cannot be one of the reduced dims
+    if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
+        # pyre-ignore[6]
+        reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+        # pyre-ignore[6]
+        reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+        possible_packed_dims = {0, 1, 2}
+        possible_packed_dims.discard(reduce_dim1_whcn)
+        possible_packed_dims.discard(reduce_dim2_whcn)
+
+        packed_dim = possible_packed_dims.pop()
+        assert packed_dim in [0, 1, 2]
+
+        if packed_dim == 0:
+            inputs_storage = utils.WIDTH_PACKED_TEXTURE
+            outputs_storage = utils.WIDTH_PACKED_TEXTURE
+        elif packed_dim == 1:
+            inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+        else:
+            inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+            outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+    return inputs_storage, outputs_storage
+
+
+@update_features(
+    [
+        exir_ops.edge.aten.mean.dim,
+        exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.amax.default,
+        exir_ops.edge.aten.amin.default,
+    ]
+)
+def register_reduce_op():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_reduce_node,
-        pick_io_storage_fn=pick_io_storage_for_reduce,
+        are_node_inputs_supported_fn=is_reduce_node_supported,
+        pick_io_storage_fn=pick_storage_for_reduce,
     )
 
 
@@ -515,6 +573,7 @@ def register_2d_pool_op():
 def register_convolution_op():
     def check_conv_node(node: torch.fx.Node) -> bool:
         x = node.args[0]
+        assert isinstance(x, torch.fx.Node)
         x_shape = x.meta["val"].size()
         # 4-D input implies 2D convolution
         if len(x_shape) == 4:
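To make the routing in the registry changes above concrete, here is a small, self-contained sketch (hypothetical helper names; the real logic lives in the `utils` module assumed by the diff) of how a possibly-negative reduced dim is normalized and tested against the per-row special case:

# Minimal sketch: normalize a possibly-negative dim, then check whether it is
# the innermost dim of the input (the per-row special case).
def normalize_dim(dim: int, ndim: int) -> int:
    return dim % ndim  # e.g. dim=-1 with ndim=3 -> 2

def is_per_row_reduction(ndim: int, dim: int) -> bool:
    return normalize_dim(dim, ndim) == ndim - 1

print(is_per_row_reduction(ndim=3, dim=-1))  # True  -> buffer-backed per-row path
print(is_per_row_reduction(ndim=3, dim=1))   # False -> general texture-based path

In the actual registry code, get_dims_reduced performs the normalization and is_reduce_node_supported_by_per_row_impl performs the `ndim - 1` comparison.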
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef REDUCE_OP_DEFS_GLSLH
+#define REDUCE_OP_DEFS_GLSLH
+
+struct Accum {
+  T val;
+  uint idx;
+  uint count;
+};
+
+void init_accum(out Accum accum, T val, uint idx) {
+  accum.val = val;
+  accum.idx = idx;
+  accum.count = 1;
+}
+
+void init_accum_zero(out Accum accum) {
+  accum.val = T(0);
+  accum.idx = 0;
+  accum.count = 0;
+}
+
+// Sum / Mean
+
+void update_accum_sum(inout Accum accum, T val, uint idx) {
+  accum.val += val;
+  accum.count += 1;
+}
+
+void merge_accum_sum(inout Accum accum, const Accum other) {
+  accum.val += other.val;
+  accum.count += other.count;
+}
+
+void postprocess_accum_mean(inout Accum accum) {
+  accum.val /= T(accum.count);
+}
+
+// Amax (maximum value)
+
+void update_accum_amax(inout Accum accum, T val, uint idx) {
+  if (val > accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // For equivalence, select the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amax(inout Accum accum, const Accum other) {
+  if (other.val > accum.val) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // For equivalence, select the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+// Amin (minimum value)
+
+void update_accum_amin(inout Accum accum, T val, uint idx) {
+  if (val < accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // For equivalence, select the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amin(inout Accum accum, const Accum other) {
+  if (other.count > 0 && (accum.count == 0 || other.val < accum.val)) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // For equivalence, select the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+#endif // REDUCE_OP_DEFS_GLSLH
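The helpers in this new header follow an update/merge/postprocess accumulator pattern: each invocation builds a partial Accum over its slice of a row, partials are merged, and (for mean) a final postprocess divides by the count. A simplified Python model of the amax/argmax case (not a line-for-line port of the GLSL; the two-way chunk split is illustrative):

from dataclasses import dataclass

@dataclass
class Accum:
    val: float = 0.0
    idx: int = 0
    count: int = 0

def update_amax(a: Accum, val: float, idx: int) -> None:
    # Track the running max; on ties keep the lower index.
    if a.count == 0 or val > a.val or (val == a.val and idx < a.idx):
        a.val, a.idx = val, idx
    a.count += 1

def merge_amax(a: Accum, b: Accum) -> None:
    # Combine two partial accumulators, again breaking ties toward the lower index.
    if b.count and (a.count == 0 or b.val > a.val or (b.val == a.val and b.idx < a.idx)):
        a.val, a.idx = b.val, b.idx
    a.count += b.count

row = [3.0, 9.0, 9.0, 1.0]
partials = []
for base in (0, 2):                      # two "threads", each covering half the row
    acc = Accum()
    for i, v in enumerate(row[base:base + 2]):
        update_amax(acc, v, base + i)    # per-thread accumulation
    partials.append(acc)

final = partials[0]
merge_amax(final, partials[1])           # cross-thread merge
print(final.val, final.idx)              # 9.0 1 -> amax value and argmax index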
