|  | 
| 1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. | 
| 2 |  | -# Copyright 2024-2025 Arm Limited and/or its affiliates. | 
| 3 | 2 | # All rights reserved. | 
|  | 3 | +# Copyright 2024-2025 Arm Limited and/or its affiliates. | 
| 4 | 4 | # | 
| 5 | 5 | # This source code is licensed under the BSD-style license found in the | 
| 6 | 6 | # LICENSE file in the root directory of this source tree. | 
| @@ -714,23 +714,30 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): | 
| 714 | 714 |             assert ( | 
| 715 | 715 |                 ref.shape == model.shape | 
| 716 | 716 |             ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}" | 
| 717 |  | -            assert torch.allclose( | 
| 718 |  | -                model, | 
| 719 |  | -                ref, | 
| 720 |  | -                atol=atol, | 
| 721 |  | -                rtol=rtol, | 
| 722 |  | -            ), ( | 
| 723 |  | -                f"Output {i} does not match reference output.\n" | 
| 724 |  | -                f"\tGiven atol: {atol}, rtol: {rtol}.\n" | 
| 725 |  | -                f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" | 
| 726 |  | -                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" | 
| 727 |  | -                f"\t-- Model vs. Reference --\n" | 
| 728 |  | -                f"\t Numel: {model.numel()}, {ref.numel()}\n" | 
| 729 |  | -                f"\tMedian: {model.median()}, {ref.median()}\n" | 
| 730 |  | -                f"\t  Mean: {model.mean()}, {ref.mean()}\n" | 
| 731 |  | -                f"\t   Max: {model.max()}, {ref.max()}\n" | 
| 732 |  | -                f"\t   Min: {model.min()}, {ref.min()}\n" | 
| 733 |  | -            ) | 
|  | 717 | +            if model.dtype == torch.bool: | 
|  | 718 | +                assert torch.equal(model, ref), ( | 
|  | 719 | +                    f"Output {i} (bool tensor) does not match reference output.\n" | 
|  | 720 | +                    f"\tShape: {model.shape}\n" | 
|  | 721 | +                    f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n" | 
|  | 722 | +                ) | 
|  | 723 | +            else: | 
|  | 724 | +                assert torch.allclose( | 
|  | 725 | +                    model, | 
|  | 726 | +                    ref, | 
|  | 727 | +                    atol=atol, | 
|  | 728 | +                    rtol=rtol, | 
|  | 729 | +                ), ( | 
|  | 730 | +                    f"Output {i} does not match reference output.\n" | 
|  | 731 | +                    f"\tGiven atol: {atol}, rtol: {rtol}.\n" | 
|  | 732 | +                    f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" | 
|  | 733 | +                    f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" | 
|  | 734 | +                    f"\t-- Model vs. Reference --\n" | 
|  | 735 | +                    f"\t Numel: {model.numel()}, {ref.numel()}\n" | 
|  | 736 | +                    f"\tMedian: {model.median()}, {ref.median()}\n" | 
|  | 737 | +                    f"\t  Mean: {model.mean()}, {ref.mean()}\n" | 
|  | 738 | +                    f"\t   Max: {model.max()}, {ref.max()}\n" | 
|  | 739 | +                    f"\t   Min: {model.min()}, {ref.min()}\n" | 
|  | 740 | +                ) | 
| 734 | 741 | 
 | 
| 735 | 742 |     @staticmethod | 
| 736 | 743 |     def _compare_outputs( | 
|  | 
0 commit comments