Skip to content

Commit 7e6b36c

Browse files
committed
chore: add missing scripts
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent beeac7c commit 7e6b36c

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

tests/py/hw/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
COSINE_THRESHOLD=0.99
4+
5+
def cosine_similarity(gt_tensor, pred_tensor):
6+
gt_tensor = gt_tensor.flatten().to(torch.float32)
7+
pred_tensor = pred_tensor.flatten().to(torch.float32)
8+
if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
9+
if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
10+
return 1.0
11+
res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
12+
res = res.cpu().detach().item()
13+
14+
return res

tests/py/integrations/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
COSINE_THRESHOLD=0.99
4+
5+
def cosine_similarity(gt_tensor, pred_tensor):
6+
gt_tensor = gt_tensor.flatten().to(torch.float32)
7+
pred_tensor = pred_tensor.flatten().to(torch.float32)
8+
if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
9+
if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
10+
return 1.0
11+
res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
12+
res = res.cpu().detach().item()
13+
14+
return res

tests/py/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
COSINE_THRESHOLD=0.99
4+
5+
def cosine_similarity(gt_tensor, pred_tensor):
6+
gt_tensor = gt_tensor.flatten().to(torch.float32)
7+
pred_tensor = pred_tensor.flatten().to(torch.float32)
8+
if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
9+
if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
10+
return 1.0
11+
res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
12+
res = res.cpu().detach().item()
13+
14+
return res

0 commit comments

Comments
 (0)