Skip to content

Commit c6f3103

Browse files
committed
chore: move to cosine similarity comparison
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 5756169 commit c6f3103

File tree

3 files changed

+28
-30
lines changed

3 files changed

+28
-30
lines changed

tests/cpp/test_collections.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
4242
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
4343
auto trt_out = trt_mod.forward(inputs_);
4444

45-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
45+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99));
4646
}
4747

4848
TEST(CppAPITests, TestCollectionTupleInput) {
@@ -85,7 +85,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
8585
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
8686
auto trt_out = trt_mod.forward(complex_inputs);
8787

88-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
88+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99));
8989
}
9090

9191
TEST(CppAPITests, TestCollectionListInput) {
@@ -144,7 +144,7 @@ TEST(CppAPITests, TestCollectionListInput) {
144144
LOG_DEBUG("Finish compile");
145145
auto trt_out = trt_mod.forward(complex_inputs);
146146

147-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
147+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99));
148148
}
149149

150150
TEST(CppAPITests, TestCollectionTupleInputOutput) {
@@ -192,7 +192,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {
192192
auto trt_out = trt_mod.forward(complex_inputs);
193193

194194
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
195-
out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
195+
out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
196196
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
197197
out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
198198
}
@@ -317,4 +317,4 @@ TEST(CppAPITests, TestCollectionComplexModel) {
317317
out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
318318
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
319319
out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
320-
}
320+
}

tests/py/api/test_collections.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torchvision.models as models
55
import os
6+
from utils import cosine_similarity, COSINE_THRESHOLD
67

78

89
def find_repo_root(max_depth=10):
@@ -40,12 +41,8 @@ def test_compile(self):
4041
}
4142

4243
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
43-
same = (
44-
(trt_mod(self.input, self.input) - self.model(self.input, self.input))
45-
.abs()
46-
.max()
47-
)
48-
self.assertTrue(same < 2e-2)
44+
cos_sim = cosine_similarity(self.model(self.input, self.input), trt_mod(self.input, self.input))
45+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"standard_tensor_input_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
4946

5047

5148
class TestTupleInput(unittest.TestCase):
@@ -68,12 +65,8 @@ def test_compile(self):
6865
}
6966

7067
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
71-
same = (
72-
(trt_mod((self.input, self.input)) - self.model((self.input, self.input)))
73-
.abs()
74-
.max()
75-
)
76-
self.assertTrue(same < 2e-2)
68+
cos_sim = cosine_similarity(self.model((self.input, self.input)), trt_mod((self.input, self.input)))
69+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"tuple_input_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
7770

7871

7972
class TestListInput(unittest.TestCase):
@@ -94,12 +87,8 @@ def test_compile(self):
9487
}
9588

9689
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
97-
same = (
98-
(trt_mod([self.input, self.input]) - self.model([self.input, self.input]))
99-
.abs()
100-
.max()
101-
)
102-
self.assertTrue(same < 2e-2)
90+
cos_sim = cosine_similarity(self.model([self.input, self.input]), trt_mod([self.input, self.input]))
91+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"list_input_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
10392

10493

10594
class TestTupleInputOutput(unittest.TestCase):
@@ -124,8 +113,9 @@ def test_compile(self):
124113
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
125114
trt_out = trt_mod((self.input, self.input))
126115
pyt_out = self.model((self.input, self.input))
127-
results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)]
128-
self.assertTrue(all(results))
116+
for (t, p) in zip(trt_out, pyt_out):
117+
cos_sim = cosine_similarity(t, p)
118+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
129119

130120

131121
class TestListInputOutput(unittest.TestCase):
@@ -150,8 +140,10 @@ def test_compile(self):
150140
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
151141
trt_out = trt_mod((self.input, self.input))
152142
pyt_out = self.model((self.input, self.input))
153-
results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)]
154-
self.assertTrue(all(results))
143+
144+
for (t, p) in zip(trt_out, pyt_out):
145+
cos_sim = cosine_similarity(t, p)
146+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
155147

156148

157149
class TestListInputTupleOutput(unittest.TestCase):
@@ -176,8 +168,9 @@ def test_compile(self):
176168
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
177169
trt_out = trt_mod((self.input, self.input))
178170
pyt_out = self.model((self.input, self.input))
179-
results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)]
180-
self.assertTrue(all(results))
171+
for (t, p) in zip(trt_out, pyt_out):
172+
cos_sim = cosine_similarity(t, p)
173+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
181174

182175

183176
if __name__ == "__main__":

tests/py/api/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
COSINE_THRESHOLD=0.99
44

55
def cosine_similarity(gt_tensor, pred_tensor):
6-
res = torch.nn.functional.cosine_similarity(gt_tensor.flatten().to(torch.float32), pred_tensor.flatten().to(torch.float32), dim=0, eps=1e-6)
6+
gt_tensor = gt_tensor.flatten().to(torch.float32)
7+
pred_tensor = pred_tensor.flatten().to(torch.float32)
8+
if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
9+
if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
10+
return 1.0
11+
res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
712
res = res.cpu().detach().item()
813

914
return res

0 commit comments

Comments
 (0)