@@ -1,7 +1,7 @@
 import copy
 
 import torch
-import torchsparse_cuda
+import torchsparse_backend
 from torch.autograd import Function
 from torchsparse import *
 from torchsparse.nn.functional.convert_neighbor_map import *
@@ -35,9 +35,9 @@ def forward(ctx, |
                           device=features.device)
 
         if 'cuda' in str(features.device):
-            torchsparse_cuda.sparseconv_forward(features, out, kernel,
-                                                neighbor_map, neighbor_offset,
-                                                transpose)
+            torchsparse_backend.sparseconv_forward(features, out, kernel,
+                                                   neighbor_map,
+                                                   neighbor_offset, transpose)
         else:
             # use the native pytorch XLA APIs for the TPU.
             cur_st = 0
@@ -69,10 +69,11 @@ def backward(ctx, grad_out): |
         grad_kernel = torch.zeros(K, c_in, c_out, device=kernel.device)
 
         if 'cuda' in str(features.device):
-            torchsparse_cuda.sparseconv_backward(features, grad_features,
-                                                 grad_out.contiguous(), kernel,
-                                                 grad_kernel, neighbor_map,
-                                                 neighbor_offset, transpose)
+            torchsparse_backend.sparseconv_backward(features, grad_features,
+                                                    grad_out.contiguous(),
+                                                    kernel, grad_kernel,
+                                                    neighbor_map,
+                                                    neighbor_offset, transpose)
         else:
             raise NotImplementedError
         return grad_features, grad_kernel, None, None, None, None
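
For context, the code this diff touches dispatches on the tensor's device: CUDA tensors go to the compiled extension (now imported as torchsparse_backend), everything else takes a fallback path. Below is a minimal, self-contained sketch of that device-check pattern only; dense_fallback_conv and dispatch_conv are hypothetical names used for illustration and are not part of torchsparse.

import torch


def dense_fallback_conv(features, kernel):
    # Hypothetical stand-in for the sparse kernels: contract the input
    # features with every kernel slice and sum over the kernel dimension.
    # features: (N, c_in), kernel: (K, c_in, c_out) -> (N, c_out)
    return torch.einsum('nc,kco->no', features, kernel)


def dispatch_conv(features, kernel):
    if 'cuda' in str(features.device):
        # The real code calls torchsparse_backend.sparseconv_forward(...)
        # here; this sketch does not build the compiled extension.
        raise RuntimeError('compiled backend not available in this sketch')
    # CPU fallback, analogous to the else-branch in the diff above.
    return dense_fallback_conv(features, kernel)


features = torch.randn(8, 4)    # N x c_in
kernel = torch.randn(3, 4, 16)  # K x c_in x c_out
print(dispatch_conv(features, kernel).shape)  # torch.Size([8, 16])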
|