|
1 | | -# numpy |
| 1 | +import math |
| 2 | +import random |
| 3 | +import unittest |
| 4 | +from functools import reduce |
| 5 | + |
2 | 6 | import torch |
| 7 | + |
3 | 8 | import intel_pytorch_extension as ipex |
4 | | -# import pcl_embedding_bag |
5 | | -# import time |
6 | | - |
7 | | -def interact_fusion(x, ly): |
8 | | - A = [x] + ly |
9 | | - R = ipex.interaction(*A) |
10 | | - return R |
11 | | - |
12 | | -def interact_features(x, ly): |
13 | | - (batch_size, d) = x.shape |
14 | | - T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) |
15 | | - # Z = pcl_embedding_bag.bdot(T) |
16 | | - Z = torch.bmm(T, torch.transpose(T, 1, 2)) |
17 | | - _, ni, nj = Z.shape |
18 | | - offset = 0 |
19 | | - li = torch.tensor([i for i in range(ni) for j in range(i + offset)], device=ipex.DEVICE) |
20 | | - lj = torch.tensor([j for i in range(nj) for j in range(i + offset)], device=ipex.DEVICE) |
21 | | - Zflat = Z[:, li, lj] |
22 | | - # concatenate dense features and interactions |
23 | | - R = torch.cat([x] + [Zflat], dim=1) |
24 | | - return R |
25 | | - |
26 | | -def run(dtype='float32'): |
27 | | - print("##################### testing with %s"% str(dtype)) |
28 | | - x1 = torch.randn([2048, 128], device=ipex.DEVICE).to(dtype).clone().detach().requires_grad_() |
29 | | - x2 = x1.clone().detach().requires_grad_() |
30 | | - ly1 = [] |
31 | | - ly2 = [] |
32 | | - for i in range(0, 26): |
33 | | - V = torch.randn([2048, 128], device=ipex.DEVICE).to(dtype).clone().detach().requires_grad_() |
34 | | - ly1.append(V) |
35 | | - ly2.append(V.clone().detach().requires_grad_()) |
36 | | - |
37 | | - print("##################### interaction forward") |
38 | | - A = interact_fusion(x1, ly1) |
39 | | - B = interact_features(x2, ly2) |
40 | | - if(A.allclose(B, rtol=1e-5, atol=1e-5)): |
41 | | - print("##################### interaction forward PASS") |
42 | | - else: |
43 | | - print("##################### interaction forward FAIL") |
44 | | - |
45 | | - print("##################### interaction backward") |
46 | | - A.mean().backward() |
47 | | - B.mean().backward() |
48 | | - ret = x1.grad.allclose(x2.grad, rtol=1e-5, atol=1e-5) |
49 | | - ret = ret and all(ly1[i].grad.allclose(ly2[i].grad, rtol=1e-5, atol=1e-5) for i in range(0, 26)) |
50 | | - if (ret): |
51 | | - print("##################### interaction backward PASS") |
52 | | - else: |
53 | | - print("##################### interaction backward FAIL") |
54 | | - |
55 | | -#dtypes=[torch.float32, torch.bfloat16] |
56 | | -dtypes=[torch.float32] |
57 | | -for d in dtypes: |
58 | | - run(d) |
| 9 | + |
| 10 | +import torch.nn as nn |
| 11 | +import torch.backends.cudnn as cudnn |
| 12 | +from torch.nn import Parameter |
| 13 | +import torch.nn.functional as F |
| 14 | +from torch.autograd import gradcheck |
| 15 | +from torch.autograd.gradcheck import gradgradcheck |
| 16 | +from torch._six import inf, nan |
| 17 | + |
| 18 | +from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \ |
| 19 | + TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \ |
| 20 | + IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \ |
| 21 | + IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \ |
| 22 | + skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf |
| 23 | + |
class TestInteractionCases(TestCase):
    def test_interaction(self):
        """Compare ipex's fused interaction op against a reference
        implementation (batched matmul + lower-triangular gather),
        for both the forward output and the input gradients."""

        def fused(dense, sparse_list):
            # Fused kernel under test: dense features first, then the
            # sparse/embedding features.
            return ipex.interaction(*([dense] + sparse_list))

        def reference(dense, sparse_list):
            # Reference path: stack all features, take pairwise dot
            # products via bmm, keep only the strictly-lower-triangular
            # entries, and prepend the dense features.
            batch_size, dim = dense.shape
            stacked = torch.cat([dense] + sparse_list, dim=1).view((batch_size, -1, dim))
            # Z = pcl_embedding_bag.bdot(T)
            gram = torch.bmm(stacked, torch.transpose(stacked, 1, 2))
            _, rows, cols = gram.shape
            # Indices of the strictly-lower triangle (j < i).
            row_idx = torch.tensor([r for r in range(rows) for _ in range(r)], device=ipex.DEVICE)
            col_idx = torch.tensor([c for r in range(cols) for c in range(r)], device=ipex.DEVICE)
            flat = gram[:, row_idx, col_idx]
            # concatenate dense features and interactions
            return torch.cat([dense, flat], dim=1)

        for dtype in [torch.float32]:
            # Two independent leaf copies of every input so each path
            # accumulates its own gradients.
            dense_a = torch.randn([2048, 128], device=ipex.DEVICE).to(dtype).clone().detach().requires_grad_()
            dense_b = dense_a.clone().detach().requires_grad_()
            sparse_a = []
            sparse_b = []
            for _ in range(26):
                emb = torch.randn([2048, 128], device=ipex.DEVICE).to(dtype).clone().detach().requires_grad_()
                sparse_a.append(emb)
                sparse_b.append(emb.clone().detach().requires_grad_())

            out_fused = fused(dense_a, sparse_a)
            out_ref = reference(dense_b, sparse_b)
            self.assertEqual(out_fused, out_ref)

            out_fused.mean().backward()
            out_ref.mean().backward()
            self.assertEqual(dense_a.grad, dense_b.grad)
            for emb_a, emb_b in zip(sparse_a, sparse_b):
                self.assertEqual(emb_a.grad, emb_b.grad)
| 65 | + |
if __name__ == '__main__':
    # Run the test suite. unittest.main() raises SystemExit when it
    # finishes, so the original `test = unittest.main()` binding was
    # dead code — the assignment could never be observed.
    unittest.main()
0 commit comments