
Commit 6c1f9fa

SS-JIA and ssjia authored
[ET-VK][ez] Fix partitioner logic of finding keepdim arg of reduce ops (#13598)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #13597
* #13596
* #13595
* #13594
* #13593
* #13600
* #13599
* __->__ #13598

Title says it all. Reduce ops do not all share the same signature, so some extra legwork is needed to identify the specific arguments that must be checked. Also included is a small update to partitioner logging to improve debuggability.

Differential Revision: [D80741737](https://our.internmc.facebook.com/intern/diff/D80741737/)

Co-authored-by: ssjia <[email protected]>
1 parent 523aa45 commit 6c1f9fa

File tree

3 files changed: +12 −7 lines changed

backends/vulkan/op_registry.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -397,14 +397,17 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         # If we can't get memory layout information, we'll assume the dims aren't packed
         pass

-    keepdim = node.args[2]
-    if isinstance(keepdim, bool) and not keepdim:
+    def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
+        for arg in node.args:
+            if isinstance(arg, bool):
+                return arg
+
+        # Assume false by default
         return False

-    if len(node.args) > 2:
-        keepdim = node.args[2]
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    keepdim = try_find_keepdim_arg(node)
+    if isinstance(keepdim, bool) and not keepdim:
+        return False

     return True
```
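The fix above replaces a hard-coded `node.args[2]` lookup with a scan for the first boolean argument, since reduce-op signatures place `keepdim` at different positions. A minimal standalone sketch of that logic, using a hypothetical `FakeNode` stand-in for `torch.fx.Node` so it runs without torch installed:

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node, carrying only args."""
    args: Tuple[Any, ...]


def try_find_keepdim_arg(node: FakeNode) -> bool:
    # Reduce op signatures differ, so scan all args for the first bool
    # instead of assuming keepdim sits at a fixed position like args[2].
    for arg in node.args:
        if isinstance(arg, bool):
            return arg
    # Assume False by default when no bool arg is present
    return False


# e.g. a mean.dim-style call (self, dim, keepdim) vs. a plain sum(self)
print(try_find_keepdim_arg(FakeNode(args=("input", [1], True))))  # True
print(try_find_keepdim_arg(FakeNode(args=("input",))))            # False
```

Note the scan assumes the first boolean in `args` is `keepdim`, which holds for the reduce-op signatures the partitioner checks but is not a general property of all ops.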
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -204,7 +204,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
     def log_skip(self, node: torch.fx.Node, reason: str) -> None:
         if node.op == "call_function":
             logger.info(
-                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
+                f"[Vulkan Partitioner] Due to [{reason}], skipping {utils.node_io_str(node)}"
             )

     def is_node_supported(
```

backends/vulkan/utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1059,6 +1059,8 @@ def get_node_val_str(node: torch.fx.Node) -> str:
         assert isinstance(node.meta["val"], (list, tuple))
         return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
     else:
+        if "val" not in node.meta:
+            return str(node)
         return str(node.meta["val"])

```
