
Commit 7cef962

Author: ssjia (committed)
[ET-VK] Introduce specialized implementation for per-row reduction
Pull Request resolved: #15161

Title says it all! This diff also adds support for argmin and argmax, but only for per-row reduction.

ghstack-source-id: 317071315
@exported-using-ghexport
Differential Revision: D84716454 (https://our.internmc.facebook.com/intern/diff/D84716454/)
1 parent 5997399 commit 7cef962
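
For context: "per-row reduction" means reducing only the innermost (contiguous) dimension, producing one output per row. A minimal PyTorch sketch of the patterns this special case targets (shapes and ops here are illustrative, not taken from the diff):

import torch

x = torch.rand(4, 8, 128)

# Reductions over only the last (innermost) dim are candidates for the
# specialized per-row implementation.
row_sums = torch.sum(x, dim=-1, keepdim=True)    # shape [4, 8, 1]
row_means = torch.mean(x, dim=-1, keepdim=True)  # shape [4, 8, 1]

# argmin/argmax are newly supported, but only for this per-row pattern.
row_argmax = torch.argmax(x, dim=-1, keepdim=True)
row_argmin = torch.argmin(x, dim=-1, keepdim=True)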

File tree

9 files changed: +615 −57 lines changed


backends/vulkan/op_registry.py

Lines changed: 110 additions & 56 deletions
@@ -420,6 +420,114 @@ def register_softmax_op():
     )
 
 
+def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
+    ndim = utils.ndim_of(node.args[0])
+    dims_reduced = None
+    if len(node.args) >= 2:
+        dims_reduced = node.args[1]
+
+    # If dim_list is None, return a list containing all the dims of the tensor
+    if dims_reduced is None:
+        dims_reduced = list(range(ndim))
+
+    # Special case for reducing tensors with shape [1, N] - this is equivalent
+    # to reducing the last dim.
+    if utils.is_unsqueezed_vector(node) and ndim == 2:
+        dims_reduced = 1
+
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
+        dims_reduced = dims_reduced[0]
+
+    return utils.normalize_dims(dims_reduced, ndim)
+
+
+def get_keepdim_setting(node: torch.fx.Node) -> bool:
+    for arg in node.args:
+        if isinstance(arg, bool):
+            return arg
+
+    # Assume False by default
+    return False
+
+
+def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
+    """
+    Checks if a reduction node is supported by the Vulkan backend's
+    reduce-per-row special-case implementation.
+    """
+    input_ndim = utils.ndim_of(node.args[0])
+    dims_reduced = get_dims_reduced(node)
+
+    return dims_reduced == input_ndim - 1
+
+
+def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool:
+    dims_reduced = get_dims_reduced(node)
+    # Only 1D and 2D reductions are supported at the moment.
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) > 2:
+        return False
+
+    keepdim = get_keepdim_setting(node)
+    # keepdim = False is not yet supported by the general implementation
+    if isinstance(keepdim, bool) and not keepdim:
+        return False
+
+    return True
+
+
+def is_reduce_node_supported(node: torch.fx.Node) -> bool:
+    # 0-dim outputs are unsupported at the moment
+    if utils.ndim_of(node) == 0:
+        return False
+
+    return is_reduce_node_supported_by_per_row_impl(
+        node
+    ) or is_reduce_node_supported_by_general_impl(node)
+
+
+def pick_storage_for_reduce(node: torch.fx.Node):
+    inputs_storage = utils.NO_STORAGE
+    outputs_storage = utils.NO_STORAGE
+
+    ndim = utils.ndim_of(node.args[0])
+    dim_list = node.args[1] if len(node.args) > 1 else None
+
+    if is_reduce_node_supported_by_general_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE)
+        outputs_storage = inputs_storage
+
+    # For 1D reductions of the last dim, a special reduce-per-row case is
+    # implemented for buffer-backed tensors.
+    if is_reduce_node_supported_by_per_row_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER)
+        outputs_storage = inputs_storage
+        return inputs_storage, outputs_storage
+
+    # For 2D reductions, the packed dimension cannot be one of the reduced dims
+    if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
+        reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+        reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+        possible_packed_dims = {0, 1, 2}
+        possible_packed_dims.discard(reduce_dim1_whcn)
+        possible_packed_dims.discard(reduce_dim2_whcn)
+
+        packed_dim = possible_packed_dims.pop()
+        assert packed_dim in [0, 1, 2]
+
+        if packed_dim == 0:
+            inputs_storage = utils.WIDTH_PACKED_TEXTURE
+            outputs_storage = utils.WIDTH_PACKED_TEXTURE
+        elif packed_dim == 1:
+            inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+        else:
+            inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+            outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+    return inputs_storage, outputs_storage
+
+
 @update_features(
     [
         exir_ops.edge.aten.mean.dim,
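
The helpers above are easiest to follow on concrete inputs. Below is a hedged, standalone sketch (not part of the commit) that mirrors get_dims_reduced's normalization rules on plain values, with an assumed normalize_dims that simply wraps negative dims; the real utils.normalize_dims may differ:

def normalize_dims(dims, ndim: int):
    # Assumed behavior of utils.normalize_dims: wrap negative dims.
    if isinstance(dims, int):
        return dims % ndim
    return [d % ndim for d in dims]

def mirror_get_dims_reduced(ndim: int, dim_arg):
    dims_reduced = dim_arg
    if dims_reduced is None:  # no dim argument: reduce over all dims
        dims_reduced = list(range(ndim))
    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
        dims_reduced = dims_reduced[0]  # unwrap single-element dim lists
    return normalize_dims(dims_reduced, ndim)

assert mirror_get_dims_reduced(3, [-1]) == 2       # last dim of a 3D tensor
assert mirror_get_dims_reduced(2, None) == [0, 1]  # full reduction
# The per-row check is then simply: dims_reduced == input_ndim - 1
assert mirror_get_dims_reduced(3, [-1]) == 3 - 1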
@@ -429,66 +537,12 @@ def register_softmax_op():
     ]
 )
 def register_reduce_op():
-    def check_reduce_node(node: torch.fx.Node) -> bool:
-        # Only one argument implies that the reduction is over the entire tensor,
-        # which is not supported yet.
-        if len(node.args) == 1:
-            return False
-
-        dim_list = node.args[1]
-        # Only 1D and 2D reductions are supported at the moment.
-        if isinstance(dim_list, list) and len(dim_list) > 2:
-            return False
-
-        def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
-            for arg in node.args:
-                if isinstance(arg, bool):
-                    return arg
-
-            # Assume false by default
-            return False
-
-        keepdim = try_find_keepdim_arg(node)
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
-
-        return True
-
-    def pick_io_storage_for_reduce(node: torch.fx.Node):
-        inputs_storage = utils.ANY_TEXTURE
-        outputs_storage = utils.ANY_TEXTURE
-
-        input_tensor = node.args[0]
-        ndim = input_tensor.meta["val"].ndim
-        dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
-            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
-
-            possible_packed_dims = {0, 1, 2}
-            possible_packed_dims.discard(reduce_dim1_whcn)
-            possible_packed_dims.discard(reduce_dim2_whcn)
-
-            packed_dim = possible_packed_dims.pop()
-            assert packed_dim in [0, 1, 2]
-
-            if packed_dim == 0:
-                inputs_storage = utils.WIDTH_PACKED_TEXTURE
-                outputs_storage = utils.WIDTH_PACKED_TEXTURE
-            elif packed_dim == 1:
-                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
-                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
-            else:
-                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
-                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
-
-        return inputs_storage, outputs_storage
 
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_reduce_node,
-        pick_io_storage_fn=pick_io_storage_for_reduce,
+        are_node_inputs_supported_fn=is_reduce_node_supported,
+        pick_io_storage_fn=pick_storage_for_reduce,
     )
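To make the 2D packed-dim selection above concrete, here is a small worked example. It assumes nchw_dim_to_whcn_dim maps an NCHW dim index into WHCN order by reversal (whcn = ndim - 1 - dim); that mapping is an assumption about utils, not something shown in this diff:

def nchw_dim_to_whcn_dim(dim: int, ndim: int) -> int:
    # Assumed mapping: NCHW dim order reversed into WHCN order.
    return ndim - 1 - dim

# Example: reduce dims [2, 3] (H and W) of a 4D NCHW tensor.
ndim, dim_list = 4, [2, 3]
possible_packed_dims = {0, 1, 2}
for d in dim_list:
    possible_packed_dims.discard(nchw_dim_to_whcn_dim(d, ndim))

# H maps to whcn 1 and W to whcn 0, leaving whcn 2 (channels),
# so the tensors would use CHANNELS_PACKED_TEXTURE storage.
assert possible_packed_dims.pop() == 2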

reduce_op_defs.glslh

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef REDUCE_OP_DEFS_GLSLH
+#define REDUCE_OP_DEFS_GLSLH
+
+struct Accum {
+  T val;
+  uint idx;
+  uint count;
+};
+
+void init_accum(out Accum accum, T val, uint idx) {
+  accum.val = val;
+  accum.idx = idx;
+  accum.count = 1;
+}
+
+void init_accum_zero(out Accum accum) {
+  accum.val = T(0);
+  accum.idx = 0;
+  accum.count = 0;
+}
+
+// Sum / Mean
+
+// idx is unused for sum; the signature matches the other update functions.
+void update_accum_sum(inout Accum accum, T val, uint idx) {
+  accum.val += val;
+  accum.count += 1;
+}
+
+void merge_accum_sum(inout Accum accum, const Accum other) {
+  accum.val += other.val;
+  accum.count += other.count;
+}
+
+void postprocess_accum_mean(inout Accum accum) {
+  accum.val /= T(accum.count);
+}
+
+// Amax (maximum value)
+
+void update_accum_amax(inout Accum accum, T val, uint idx) {
+  if (val > accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // On ties, select the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amax(inout Accum accum, const Accum other) {
+  if (other.val > accum.val) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // On ties, select the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+// Amin (minimum value)
+
+void update_accum_amin(inout Accum accum, T val, uint idx) {
+  if (val < accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // On ties, select the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amin(inout Accum accum, const Accum other) {
+  // Treat an empty accumulator (count == 0) as having no value yet.
+  if (other.count > 0 && (accum.count == 0 || other.val < accum.val)) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // On ties, select the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+#endif // REDUCE_OP_DEFS_GLSLH
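
As a sanity check on the accumulator semantics, here is a Python reference model (not part of the commit) of the amax update rule, including the lower-index tie-break that makes argmax deterministic:

from dataclasses import dataclass

@dataclass
class Accum:
    val: float
    idx: int = 0
    count: int = 0

def update_accum_amax(accum: Accum, val: float, idx: int) -> None:
    # Mirrors the GLSL update_accum_amax above.
    if val > accum.val:
        accum.val, accum.idx = val, idx
    # On ties, keep the lower index.
    if val == accum.val and idx < accum.idx:
        accum.idx = idx

acc = Accum(val=float("-inf"))
for i, v in enumerate([1.0, 3.0, 3.0, 2.0]):
    update_accum_amax(acc, v, i)
assert (acc.val, acc.idx) == (3.0, 1)  # max is 3.0, first seen at index 1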
