Commit 1d747aa

fixing rebase
1 parent 01b2c30 commit 1d747aa

File tree

3 files changed: +62 −2 lines changed


tests/test_attention.py

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@
 
 import torch
 from torch.library import opcheck
-import torch.nn as nn
 
 # from torch.autograd import gradcheck
 from torch_harmonics import AttentionS2, NeighborhoodAttentionS2

tests/test_permute.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@
 import unittest
 from parameterized import parameterized, parameterized_class
 
-
 import torch
 from torch.library import opcheck
 from torch_harmonics.utils import permute_to_0231, permute_to_0312

tests/testutils.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# coding=utf-8
+
+# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import torch
+
+def compare_tensors(msg, tensor1, tensor2, atol=1e-8, rtol=1e-5, verbose=False):
+
+    # some None checks
+    if tensor1 is None and tensor2 is None:
+        allclose = True
+    elif tensor1 is None and tensor2 is not None:
+        allclose = False
+        if verbose:
+            print(f"tensor1 is None and tensor2 is not None")
+    elif tensor1 is not None and tensor2 is None:
+        allclose = False
+        if verbose:
+            print(f"tensor1 is not None and tensor2 is None")
+    else:
+        diff = torch.abs(tensor1 - tensor2)
+        abs_diff = torch.mean(diff, dim=0)
+        rel_diff = torch.mean(diff / torch.clamp(torch.abs(tensor2), min=1e-6), dim=0)
+        allclose = torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol)
+        if not allclose and verbose:
+            print(f"Absolute difference on {msg}: min = {abs_diff.min()}, mean = {abs_diff.mean()}, max = {abs_diff.max()}")
+            print(f"Relative difference on {msg}: min = {rel_diff.min()}, mean = {rel_diff.mean()}, max = {rel_diff.max()}")
+            print(f"Element values with max difference on {msg}: {tensor1.flatten()[diff.argmax()]} and {tensor2.flatten()[diff.argmax()]}")
+            # find violating entry
+            worst_diff = torch.argmax(diff - (atol + rtol * torch.abs(tensor2)))
+            diff_bad = diff.flatten()[worst_diff].item()
+            tensor2_abs_bad = torch.abs(tensor2).flatten()[worst_diff].item()
+            print(f"Worst allclose condition violation: {diff_bad} <= {atol} + {rtol} * {tensor2_abs_bad} = {atol + rtol * tensor2_abs_bad}")
+
+    return allclose
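
The new compare_tensors helper in tests/testutils.py is intended to be shared by the test modules touched above. A minimal sketch of how a test might call it follows; the test class, the import path (from testutils import compare_tensors), and the tolerances shown are illustrative assumptions, not part of this commit.

import unittest

import torch

from testutils import compare_tensors  # assumed import path when running from tests/


class TestCompareTensors(unittest.TestCase):

    def test_matching_tensors(self):
        # identical tensors should compare equal at the default tolerances
        x = torch.randn(4, 8)
        self.assertTrue(compare_tensors("identity", x, x.clone()))

    def test_mismatched_tensors(self):
        # a large perturbation should be reported as a mismatch;
        # verbose=True prints the worst allclose violation for debugging
        x = torch.randn(4, 8)
        y = x + 1.0
        self.assertFalse(compare_tensors("perturbed", x, y, atol=1e-8, rtol=1e-5, verbose=True))


if __name__ == "__main__":
    unittest.main()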
