Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions backends/xnnpack/partition/config/xnnpack_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def check_common_constraints(
return True

def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
# Check inputs are valid dtypes
# Check inputs are valid and have the same dtypes
# Gather all args which are nodes
args_to_check = []
reference_dtype = None
for arg in node.args:
if isinstance(arg, list) or isinstance(arg, tuple):
for item in arg:
Expand Down Expand Up @@ -174,11 +175,32 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
if arg_val.dtype not in valid_dtypes:
return False

# Use the first dtype as reference
reference_dtype = reference_dtype or arg_val.dtype

# Check for mixed dtypes
if arg_val.dtype != reference_dtype:
# Get op name if the attribute exists, otherwise use the full node target for logging
op_name = (
node.target.__name__
if hasattr(node.target, "__name__")
else str(node.target)
)
why(
node,
reason=(
f"{op_name} does not support mixed input dtypes, "
f"got: [{reference_dtype}, {arg_val.dtype}]"
),
)
return False

return True

def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
# Check outputs are valid dtype
# Check outputs are valid
node_val = node.meta.get("val", None)

if node_val is None:
return True

Expand Down
Loading