1+ # coding=utf-8
2+
3+ # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4+ # SPDX-License-Identifier: BSD-3-Clause
5+ #
6+ # Redistribution and use in source and binary forms, with or without
7+ # modification, are permitted provided that the following conditions are met:
8+ #
9+ # 1. Redistributions of source code must retain the above copyright notice, this
10+ # list of conditions and the following disclaimer.
11+ #
12+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13+ # this list of conditions and the following disclaimer in the documentation
14+ # and/or other materials provided with the distribution.
15+ #
16+ # 3. Neither the name of the copyright holder nor the names of its
17+ # contributors may be used to endorse or promote products derived from
18+ # this software without specific prior written permission.
19+ #
20+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30+ #
31+
32+ import torch
33+
34+ def compare_tensors (msg , tensor1 , tensor2 , atol = 1e-8 , rtol = 1e-5 , verbose = False ):
35+
36+ # some None checks
37+ if tensor1 is None and tensor2 is None :
38+ allclose = True
39+ elif tensor1 is None and tensor2 is not None :
40+ allclose = False
41+ if verbose :
42+ print (f"tensor1 is None and tensor2 is not None" )
43+ elif tensor1 is not None and tensor2 is None :
44+ allclose = False
45+ if verbose :
46+ print (f"tensor1 is not None and tensor2 is None" )
47+ else :
48+ diff = torch .abs (tensor1 - tensor2 )
49+ abs_diff = torch .mean (diff , dim = 0 )
50+ rel_diff = torch .mean (diff / torch .clamp (torch .abs (tensor2 ), min = 1e-6 ), dim = 0 )
51+ allclose = torch .allclose (tensor1 , tensor2 , atol = atol , rtol = rtol )
52+ if not allclose and verbose :
53+ print (f"Absolute difference on { msg } : min = { abs_diff .min ()} , mean = { abs_diff .mean ()} , max = { abs_diff .max ()} " )
54+ print (f"Relative difference on { msg } : min = { rel_diff .min ()} , mean = { rel_diff .mean ()} , max = { rel_diff .max ()} " )
55+ print (f"Element values with max difference on { msg } : { tensor1 .flatten ()[diff .argmax ()]} and { tensor2 .flatten ()[diff .argmax ()]} " )
56+ # find violating entry
57+ worst_diff = torch .argmax (diff - (atol + rtol * torch .abs (tensor2 )))
58+ diff_bad = diff .flatten ()[worst_diff ].item ()
59+ tensor2_abs_bad = torch .abs (tensor2 ).flatten ()[worst_diff ].item ()
60+ print (f"Worst allclose condition violation: { diff_bad } <= { atol } + { rtol } * { tensor2_abs_bad } = { atol + rtol * tensor2_abs_bad } " )
61+
62+ return allclose
0 commit comments