From 9b12695de7c5b458347d404ca059c80765d0f4c8 Mon Sep 17 00:00:00 2001 From: Fang-Ching Date: Tue, 13 May 2025 13:26:47 +0100 Subject: [PATCH] Arm backend: Avoid subtraction error on boolean tensors in error diff logging Add a dtype check before diff calculation to avoid runtime errors since boolean tensors do not support arithmetic operations like subtraction. Reports the mismatch count for boolean tensors. Signed-off-by: Fang-Ching Change-Id: I9575f9fe3bd2e4dc0a1be45aa9271c84ebc57bad --- .../arm/test/tester/analyze_output_utils.py | 24 +++++++---- backends/xnnpack/test/tester/tester.py | 43 +++++++++++-------- 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 1ec0f2304aa..15a6db9a878 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -154,6 +154,13 @@ def print_error_diffs( output_str += f"BATCH {n}\n" result_batch = result[n, :, :, :] reference_batch = reference[n, :, :, :] + + if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool: + mismatches = (reference_batch != result_batch).sum().item() + total = reference_batch.numel() + output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n" + continue + is_close = torch.allclose(result_batch, reference_batch, rtol, atol) if is_close: output_str += ".\n" @@ -180,14 +187,15 @@ def print_error_diffs( output_str += _print_elements( result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol ) - - reference_range = torch.max(reference) - torch.min(reference) - diff = torch.abs(reference - result).flatten() - diff = diff[diff.nonzero()] - if not len(diff) == 0: - diff_percent = diff / reference_range - output_str += "\nMEAN MEDIAN MAX MIN (error as % of reference output range)\n" - output_str += f"{torch.mean(diff_percent):<8.2%} {torch.median(diff_percent):<8.2%} {torch.max(diff_percent):<8.2%} {torch.min(diff_percent):<8.2%}\n" + # Only compute numeric error metrics if tensor is not boolean + if reference.dtype != torch.bool and result.dtype != torch.bool: + reference_range = torch.max(reference) - torch.min(reference) + diff = torch.abs(reference - result).flatten() + diff = diff[diff.nonzero()] + if not len(diff) == 0: + diff_percent = diff / reference_range + output_str += "\nMEAN MEDIAN MAX MIN (error as % of reference output range)\n" + output_str += f"{torch.mean(diff_percent):<8.2%} {torch.median(diff_percent):<8.2%} {torch.max(diff_percent):<8.2%} {torch.min(diff_percent):<8.2%}\n" # Over-engineer separators to match output width lines = output_str.split("\n") diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index fa8edd3e03c..30e63be5f4c 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -714,23 +714,30 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): assert ( ref.shape == model.shape ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}" - assert torch.allclose( - model, - ref, - atol=atol, - rtol=rtol, - ), ( - f"Output {i} does not match reference output.\n" - f"\tGiven atol: {atol}, rtol: {rtol}.\n" - f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" - f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" - f"\t-- Model vs. Reference --\n" - f"\t Numel: {model.numel()}, {ref.numel()}\n" - f"\tMedian: {model.median()}, {ref.median()}\n" - f"\t Mean: {model.mean()}, {ref.mean()}\n" - f"\t Max: {model.max()}, {ref.max()}\n" - f"\t Min: {model.min()}, {ref.min()}\n" - ) + if model.dtype == torch.bool: + assert torch.equal(model, ref), ( + f"Output {i} (bool tensor) does not match reference output.\n" + f"\tShape: {model.shape}\n" + f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n" + ) + else: + assert torch.allclose( + model, + ref, + atol=atol, + rtol=rtol, + ), ( + f"Output {i} does not match reference output.\n" + f"\tGiven atol: {atol}, rtol: {rtol}.\n" + f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" + f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" + f"\t-- Model vs. Reference --\n" + f"\t Numel: {model.numel()}, {ref.numel()}\n" + f"\tMedian: {model.median()}, {ref.median()}\n" + f"\t Mean: {model.mean()}, {ref.mean()}\n" + f"\t Max: {model.max()}, {ref.max()}\n" + f"\t Min: {model.min()}, {ref.min()}\n" + ) @staticmethod def _compare_outputs(