Skip to content

Commit 357d573

Browse files
authored
[ET-VK] Introduce specialized implementation for per-row reduction (#15587)
Title says it all! This diff also adds support for argmin and argmax, but only for per-row reduction. Differential Revision: [D84716454](https://our.internmc.facebook.com/intern/diff/D84716454/)
1 parent 62f5703 commit 357d573

File tree

11 files changed

+624
-60
lines changed

11 files changed

+624
-60
lines changed

backends/vulkan/op_registry.py

Lines changed: 114 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -420,75 +420,131 @@ def register_softmax_op():
420420
)
421421

422422

def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
    """Return the dim(s) reduced by a reduction node, normalized via
    utils.normalize_dims.

    Returns a single int when exactly one dim is reduced, otherwise a
    list/tuple of dims. When the node carries no explicit dim argument, all
    dims of the input are assumed to be reduced.
    """
    ndim = utils.ndim_of(node.args[0])
    assert ndim is not None
    dims_reduced = None
    # The dim argument, when present, is the second positional arg.
    if len(node.args) >= 2:
        dims_reduced = node.args[1]

    # If dim_list is None, return a list containing all the dims of the tensor
    if dims_reduced is None:
        dims_reduced = list(range(ndim))

    # Special case for reducing tensors with shape [1, N] - this is equivalent to
    # reducing the last dim.
    if utils.is_unsqueezed_vector(node) and ndim == 2:
        dims_reduced = 1

    # Collapse a single-element dim list to a plain int so callers can detect
    # 1D reductions with a simple int comparison.
    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
        dims_reduced = dims_reduced[0]

    assert isinstance(dims_reduced, (int, list, tuple))
    return utils.normalize_dims(dims_reduced, ndim)
def get_keepdim_setting(node: torch.fx.Node) -> bool:
    """Return the keepdim setting of a reduction node.

    The keepdim flag is identified as the first boolean positional argument.
    If no boolean argument is present, False is assumed.
    """
    return next((arg for arg in node.args if isinstance(arg, bool)), False)
def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
    """
    Checks if a reduction node is supported by the Vulkan backend's reduce per row
    special case implementation, i.e. the reduction is over exactly the last dim
    of the input.
    """
    input_ndim = utils.ndim_of(node.args[0])
    assert input_ndim is not None
    # get_dims_reduced returns a plain int for 1D reductions, so this equality
    # only holds for a single-dim reduction over the last dim.
    return get_dims_reduced(node) == input_ndim - 1
def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool:
    """Checks if a reduction node is supported by the general reduction
    implementation."""
    dims = get_dims_reduced(node)
    # Only 1D and 2D reductions are supported at the moment.
    if isinstance(dims, (list, tuple)) and len(dims) > 2:
        return False

    # keepdim = False is not supported yet for general implementation;
    # get_keepdim_setting always yields a bool, so it doubles as the verdict.
    return get_keepdim_setting(node)
def is_reduce_node_supported(node: torch.fx.Node) -> bool:
    """True if either the per-row special case or the general implementation
    can execute this reduction node."""
    if is_reduce_node_supported_by_per_row_impl(node):
        return True
    return is_reduce_node_supported_by_general_impl(node)
def pick_storage_for_reduce(node: torch.fx.Node):
    """Pick input/output storage settings for a reduction node.

    Returns a (inputs_storage, outputs_storage) pair. Storage options are
    accumulated as a union: texture storage when the general implementation
    applies, plus contiguous buffer storage when the per-row special case
    applies. For 2D texture reductions, the packed dim is constrained to a
    dim that is not being reduced.
    """
    inputs_storage = utils.NO_STORAGE
    outputs_storage = utils.NO_STORAGE

    ndim = utils.ndim_of(node.args[0])
    dim_list = get_dims_reduced(node)

    if is_reduce_node_supported_by_general_impl(node):
        inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE)
        outputs_storage = inputs_storage

    # For 1D reductions of the last dim, a special reduce per row case is implemented
    # for buffer backed tensors.
    if is_reduce_node_supported_by_per_row_impl(node):
        inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER)
        outputs_storage = inputs_storage
        # Per-row reductions are 1D, so the 2D packed-dim constraint below
        # cannot apply; return the accumulated storage options directly.
        return inputs_storage, outputs_storage

    # For 2D reductions, the packed dimension cannot be one of the reduced dims
    if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
        # pyre-ignore[6]
        reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
        # pyre-ignore[6]
        reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)

        # The remaining WHCN dim after removing the two reduced dims is the
        # only valid packed dim.
        possible_packed_dims = {0, 1, 2}
        possible_packed_dims.discard(reduce_dim1_whcn)
        possible_packed_dims.discard(reduce_dim2_whcn)

        packed_dim = possible_packed_dims.pop()
        assert packed_dim in [0, 1, 2]

        if packed_dim == 0:
            inputs_storage = utils.WIDTH_PACKED_TEXTURE
            outputs_storage = utils.WIDTH_PACKED_TEXTURE
        elif packed_dim == 1:
            inputs_storage = utils.HEIGHT_PACKED_TEXTURE
            outputs_storage = utils.HEIGHT_PACKED_TEXTURE
        else:
            inputs_storage = utils.CHANNELS_PACKED_TEXTURE
            outputs_storage = utils.CHANNELS_PACKED_TEXTURE

    return inputs_storage, outputs_storage
@update_features(
    [
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
        exir_ops.edge.aten.argmax.default,
        exir_ops.edge.aten.argmin.default,
    ]
)
def register_reduce_op():
    # All reduction ops share one feature set; per-node support and storage
    # selection are delegated to is_reduce_node_supported and
    # pick_storage_for_reduce.
    return OpFeatures(
        inputs_storage=utils.ANY_TEXTURE,
        supports_resize=True,
        are_node_inputs_supported_fn=is_reduce_node_supported,
        pick_io_storage_fn=pick_storage_for_reduce,
    )
493549

494550

@@ -515,6 +571,7 @@ def register_2d_pool_op():
515571
def register_convolution_op():
516572
def check_conv_node(node: torch.fx.Node) -> bool:
517573
x = node.args[0]
574+
assert isinstance(x, torch.fx.Node)
518575
x_shape = x.meta["val"].size()
519576
# 4-D input implies 2D convolution
520577
if len(x_shape) == 4:

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,14 @@ def texel_component_type(dtype: str) -> str:
233233
raise AssertionError(f"Invalid vec4 type: {vec4_type}")
234234

235235

def accum_vec_type(dtype: str) -> str:
    """Vec4 GLSL type used for accumulator values of *dtype*; currently an
    alias of texel_type."""
    return texel_type(dtype)
def accum_scalar_type(dtype: str) -> str:
    """Scalar GLSL type used for accumulator values of *dtype*; currently an
    alias of texel_component_type."""
    return texel_component_type(dtype)
243+
236244
def texel_load_type(dtype: str, storage_type: str) -> str:
237245
if storage_type.lower() == "buffer":
238246
return buffer_gvec_type(dtype, 4)
@@ -455,6 +463,8 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
455463
"buffer_gvec_type": buffer_gvec_type,
456464
"texel_type": texel_type,
457465
"gvec_type": gvec_type,
466+
"accum_vec_type": accum_vec_type,
467+
"accum_scalar_type": accum_scalar_type,
458468
"texel_component_type": texel_component_type,
459469
"texel_load_type": texel_load_type,
460470
"texel_load_component_type": texel_load_component_type,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#ifndef CONVERT_GLSLH
#define CONVERT_GLSLH

// Scalar Conversions
//
// Defines convert_to_T(x): cast x to the shader's scalar type T. For half
// precision, the value is first clamped to [-65504, 65504] (the finite
// float16 range) so out-of-range inputs saturate instead of overflowing.
//
// NOTE(review): the GLSL preprocessor evaluates undefined identifiers in
// `#if` as 0, so `#if T == float16_t` compares 0 == 0 (always true) unless
// the codegen defines T and float16_t as integer-valued macros — confirm the
// generator guarantees this distinguishes types.
// NOTE(review): the macro bodies end in ';', so convert_to_T can only be
// used as a full statement — verify call sites do not embed it in larger
// expressions.

#ifdef T

#if T == float16_t

#define convert_to_T(x) T(clamp(x, -65504, 65504));

#else

#define convert_to_T(x) T(x);

#endif // T == float16_t

#endif // T

#endif // CONVERT_GLSLH
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#ifndef REDUCE_OP_DEFS_GLSLH
#define REDUCE_OP_DEFS_GLSLH

// Accumulator definitions shared by the reduction shaders. The including
// shader is expected to define:
//   T       - element type of the input
//   ACCUM_T - type used to accumulate values (may be wider than T)

// State of a (partial) reduction.
struct Accum {
  ACCUM_T val; // running value (sum / max / min)
  uint idx;    // index of the current best element (used by argmax / argmin)
  uint count;  // number of elements folded in; 0 marks an empty accumulator
};

// Initialize an accumulator with its first element.
void init_accum(out Accum accum, T val, uint idx) {
  accum.val = ACCUM_T(val);
  accum.idx = idx;
  accum.count = 1;
}

// Initialize an empty accumulator; count == 0 lets the merge functions
// ignore the placeholder value.
void init_accum_zero(out Accum accum) {
  // Fix: cast with ACCUM_T (the field's type) instead of T, avoiding an
  // implicit conversion between the two generated types.
  accum.val = ACCUM_T(0);
  accum.idx = 0;
  accum.count = 0;
}

// Sum / Mean

void update_accum_sum(inout Accum accum, T val, uint idx) {
  accum.val += ACCUM_T(val);
  accum.count += 1;
}

void merge_accum_sum(inout Accum accum, const Accum other) {
  accum.val += other.val;
  accum.count += other.count;
}

void postprocess_accum_mean(inout Accum accum) {
  // Fix: divide in ACCUM_T; routing count through T could lose precision
  // (e.g. float16_t cannot represent integers above 2048 exactly).
  accum.val /= ACCUM_T(accum.count);
}

// Amax (maximum value)

void update_accum_amax(inout Accum accum, T in_val, uint idx) {
  ACCUM_T val = ACCUM_T(in_val);
  if (val > accum.val) {
    accum.val = val;
    accum.idx = idx;
  }
  // For equivalence, select the lower index
  if (val == accum.val && idx < accum.idx) {
    accum.idx = idx;
  }
}

void merge_accum_amax(inout Accum accum, const Accum other) {
  // Fix: mirror merge_accum_amin's empty-accumulator handling. Without the
  // count guards, an empty accumulator's placeholder 0 would win against
  // all-negative data.
  if (other.count > 0 && (accum.count == 0 || other.val > accum.val)) {
    accum.val = other.val;
    accum.idx = other.idx;
  }
  // For equivalence, select the lower index (ignore empty `other`)
  if (other.count > 0 && other.val == accum.val && other.idx < accum.idx) {
    accum.idx = other.idx;
  }
  // Fix: propagate count so a formerly-empty accumulator is not treated as
  // empty by subsequent merges.
  accum.count += other.count;
}

// Amin (minimum value)

void update_accum_amin(inout Accum accum, T in_val, uint idx) {
  ACCUM_T val = ACCUM_T(in_val);
  if (val < accum.val) {
    accum.val = val;
    accum.idx = idx;
  }
  // For equivalence, select the lower index
  if (val == accum.val && idx < accum.idx) {
    accum.idx = idx;
  }
}

void merge_accum_amin(inout Accum accum, const Accum other) {
  if (other.count > 0 && (accum.count == 0 || other.val < accum.val)) {
    accum.val = other.val;
    accum.idx = other.idx;
  }
  // For equivalence, select the lower index; fix: guard against an empty
  // `other` whose placeholder value could otherwise steal the tie-break.
  if (other.count > 0 && other.val == accum.val && other.idx < accum.idx) {
    accum.idx = other.idx;
  }
  // Fix: propagate count so emptiness tracking survives chained merges.
  accum.count += other.count;
}

#endif // REDUCE_OP_DEFS_GLSLH

0 commit comments

Comments
 (0)