diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b7f8f3de955..a6cc59e26f0 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -397,14 +397,17 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         # If we can't get memory layout information, we'll assume the dims aren't packed
         pass
 
-    keepdim = node.args[2]
-    if isinstance(keepdim, bool) and not keepdim:
+    def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
+        for arg in node.args:
+            if isinstance(arg, bool):
+                return arg
+
+        # Assume false by default
         return False
 
-    if len(node.args) > 2:
-        keepdim = node.args[2]
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    keepdim = try_find_keepdim_arg(node)
+    if isinstance(keepdim, bool) and not keepdim:
+        return False
 
     return True
 
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 1b5ff0a44e4..04a1a500b64 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -204,7 +204,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
     def log_skip(self, node: torch.fx.Node, reason: str) -> None:
         if node.op == "call_function":
             logger.info(
-                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
+                f"[Vulkan Partitioner] Due to [{reason}], skipping {utils.node_io_str(node)}"
             )
 
     def is_node_supported(
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 1765f0b5e1c..bc03860ed3f 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -1059,6 +1059,8 @@ def get_node_val_str(node: torch.fx.Node) -> str:
         assert isinstance(node.meta["val"], (list, tuple))
         return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
    else:
+        if "val" not in node.meta:
+            return str(node)
         return str(node.meta["val"])
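
Note on the op_registry.py change: indexing node.args[2] assumed keepdim always sits at a fixed position, which does not hold across reduce-op schemas; the new helper scans the positional args for the first bool instead. Below is a minimal standalone sketch (not part of the PR) of that lookup against mocked nodes; SimpleNamespace stands in for torch.fx.Node and the argument tuples are illustrative only, not exact FX node layouts.

# Minimal sketch, assuming the helper behaves as added to check_reduce_node.
from types import SimpleNamespace


def try_find_keepdim_arg(node) -> bool:
    # Return the first bool positional argument, or fall back to False
    # when none is present.
    for arg in node.args:
        if isinstance(arg, bool):
            return arg
    return False


# (input, dim, keepdim) -- a mean.dim-style reduce
mean_node = SimpleNamespace(args=("x", [1], True))
# keepdim followed by an extra trailing argument -- still found by the scan
sum_node = SimpleNamespace(args=("x", [1], False, "extra"))
# keepdim omitted entirely -- the helper conservatively reports False
amax_node = SimpleNamespace(args=("x", [1]))

assert try_find_keepdim_arg(mean_node) is True
assert try_find_keepdim_arg(sum_node) is False
assert try_find_keepdim_arg(amax_node) is False

If keepdim were passed as a keyword argument it would land in node.kwargs rather than node.args, so the scan would fall through to the False default and check_reduce_node would conservatively reject the node.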