From 46b708ffca9096a19fd7fb8efa9057962be1687a Mon Sep 17 00:00:00 2001 From: Zuby A Date: Tue, 25 Mar 2025 13:47:00 -0700 Subject: [PATCH 1/3] Add mixed dtype check for XNNPACK partitioner --- .../partition/config/xnnpack_config.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index 20018610fce..6806e90ac34 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -144,9 +144,10 @@ def check_common_constraints( return True def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): - # Check inputs are valid dtypes + # Check inputs are valid and have the same dtypes # Gather all args which are nodes args_to_check = [] + reference_dtype = None for arg in node.args: if isinstance(arg, list) or isinstance(arg, tuple): for item in arg: @@ -174,11 +175,19 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): if arg_val.dtype not in valid_dtypes: return False + # Check for mixed dtypes + if reference_dtype is None: + reference_dtype = arg_val.dtype + elif arg_val.dtype != reference_dtype: + return False + return True def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): - # Check outputs are valid dtype + # Check outputs are valid and have the same dtypes node_val = node.meta.get("val", None) + reference_dtype = None + if node_val is None: return True @@ -192,6 +201,12 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): if val.dtype not in valid_dtypes: return False + # Check for mixed dtypes + if reference_dtype is None: + reference_dtype = val.dtype + elif val.dtype != reference_dtype: + return False + return True def _check_node_has_valid_dtype(self, node): From 2c10b33037885712fe11637d19e82f01dcf4809a Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Wed, 26 Mar 2025 15:11:16 -0700 Subject: [PATCH 2/3] Add why log for mixed input dtypes; remove mixed dtype check from output --- .../partition/config/xnnpack_config.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index 6806e90ac34..c30098d8805 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -176,17 +176,22 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): return False # Check for mixed dtypes - if reference_dtype is None: - reference_dtype = arg_val.dtype - elif arg_val.dtype != reference_dtype: + reference_dtype = reference_dtype or arg_val.dtype + if arg_val.dtype != reference_dtype: + why( + node, + reason=( + f"{node.target} does not support mixed input dtypes. " + f"Got: [{reference_dtype}, {arg_val.dtype}]" + ), + ) return False return True def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): - # Check outputs are valid and have the same dtypes + # Check outputs are valid node_val = node.meta.get("val", None) - reference_dtype = None if node_val is None: return True @@ -201,12 +206,6 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): if val.dtype not in valid_dtypes: return False - # Check for mixed dtypes - if reference_dtype is None: - reference_dtype = val.dtype - elif val.dtype != reference_dtype: - return False - return True def _check_node_has_valid_dtype(self, node): From 2ddc16ace67d23895ecf8da36058a47dc5de57ed Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Wed, 26 Mar 2025 19:32:51 -0700 Subject: [PATCH 3/3] Improve mixed dtype logging with op name --- .../xnnpack/partition/config/xnnpack_config.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index c30098d8805..df6067a7d68 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -175,14 +175,22 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): if arg_val.dtype not in valid_dtypes: return False - # Check for mixed dtypes + # Use the first dtype as reference reference_dtype = reference_dtype or arg_val.dtype + + # Check for mixed dtypes if arg_val.dtype != reference_dtype: + # Get op name if the attribute exists, otherwise use the full node target for logging + op_name = ( + node.target.__name__ + if hasattr(node.target, "__name__") + else str(node.target) + ) why( node, reason=( - f"{node.target} does not support mixed input dtypes. " - f"Got: [{reference_dtype}, {arg_val.dtype}]" + f"{op_name} does not support mixed input dtypes, " + f"got: [{reference_dtype}, {arg_val.dtype}]" ), ) return False