Skip to content

Commit eb4dc11

Browse files
Author: ssjia (committed)
Update on "[ET-VK] Introduce specialized implementation for per-row reduction"
Title says it all! This diff also adds support for argmin and argmax, but only for per-row reduction. Differential Revision: [D84716454](https://our.internmc.facebook.com/intern/diff/D84716454/) [ghstack-poisoned]
2 parents 488bd83 + 4bd4b28 commit eb4dc11

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

backends/vulkan/op_registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def register_softmax_op():
422422

423423
def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
424424
ndim = utils.ndim_of(node.args[0])
425+
assert ndim is not None
425426
dims_reduced = None
426427
if len(node.args) >= 1:
427428
dims_reduced = node.args[1]
@@ -438,6 +439,7 @@ def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
438439
if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
439440
dims_reduced = dims_reduced[0]
440441

442+
assert isinstance(dims_reduced, (int, list, tuple))
441443
return utils.normalize_dims(dims_reduced, ndim)
442444

443445

@@ -456,6 +458,7 @@ def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
456458
special case implementation.
457459
"""
458460
input_ndim = utils.ndim_of(node.args[0])
461+
assert input_ndim is not None
459462
dims_reduced = get_dims_reduced(node)
460463

461464
return dims_reduced == input_ndim - 1
@@ -505,7 +508,9 @@ def pick_storage_for_reduce(node: torch.fx.Node):
505508

506509
# For 2D reductions, the packed dimension cannot be one of the reduced dims
507510
if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
511+
# pyre-ignore[6]
508512
reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
513+
# pyre-ignore[6]
509514
reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
510515

511516
possible_packed_dims = {0, 1, 2}
@@ -569,6 +574,7 @@ def register_2d_pool_op():
569574
def register_convolution_op():
570575
def check_conv_node(node: torch.fx.Node) -> bool:
571576
x = node.args[0]
577+
assert isinstance(x, torch.fx.Node)
572578
x_shape = x.meta["val"].size()
573579
# 4-D input implies 2D convolution
574580
if len(x_shape) == 4:

backends/vulkan/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
293293
return True
294294

295295
for arg_node in node.args:
296+
# pyre-ignore[6]
296297
if is_tensor_node(arg_node) and tensor_node_is_bool(arg_node):
297298
return True
298299

@@ -756,6 +757,7 @@ def make_filtered_tensor_repset(
756757
CONTIGUOUS_BUFFER = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set())
757758

758759
WIDTH_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_WIDTH_PACKED})
760+
HEIGHT_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_HEIGHT_PACKED})
759761
CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED})
760762

761763
ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts)

0 commit comments

Comments (0)