
Commit 313c9b0

kentang-mit and Zhijian Liu authored
Add generalized sparse convolution (#77)
Co-authored-by: Zhijian Liu <[email protected]>
1 parent 63a67ed commit 313c9b0

24 files changed (+393, -322 lines)

setup.py

Lines changed: 0 additions & 4 deletions
@@ -28,9 +28,6 @@
         'torchsparse/src/interpolation/devox_deterministic.cpp',
         'torchsparse/src/interpolation/devox_deterministic_gpu.cu',
         'torchsparse/src/interpolation/devox_cpu.cpp',
-        'torchsparse/src/others/convert_neighbor_map.cpp',
-        'torchsparse/src/others/convert_neighbor_map_gpu.cu',
-        'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
         'torchsparse/src/others/count.cpp',
         'torchsparse/src/others/count_gpu.cu',
         'torchsparse/src/others/count_cpu.cpp',
@@ -44,7 +41,6 @@
         'torchsparse/src/hash/hash_cpu.cpp',
         'torchsparse/src/hashmap/hashmap_cpu.cpp',
         'torchsparse/src/interpolation/devox_cpu.cpp',
-        'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
         'torchsparse/src/others/insertion_cpu.cpp',
         'torchsparse/src/others/query_cpu.cpp',
         'torchsparse/src/others/count_cpu.cpp'

torchsparse/nn/functional/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from .activation import *
 from .conv import *
-from .convert_neighbor_map import *
+from .squeeze_nmap import *
 from .count import *
 from .crop import *
 from .devox import *

torchsparse/nn/functional/activation.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import functools
 
 from torch.nn import functional as F
-from torchsparse.sparse_tensor import *
+from torchsparse.sparse_tensor import SparseTensor
 
 __all__ = ['spact', 'sprelu', 'spleaky_relu']
 
torchsparse/nn/functional/conv.py

Lines changed: 66 additions & 43 deletions
@@ -4,12 +4,12 @@
 import torchsparse_backend
 from torch.autograd import Function
 from torch.cuda.amp import custom_fwd, custom_bwd
-from torchsparse import *
-from torchsparse.nn.functional.convert_neighbor_map import *
-from torchsparse.nn.functional.downsample import *
-from torchsparse.nn.functional.hash import *
-from torchsparse.nn.functional.query import *
-from torchsparse.utils.kernel_region import *
+from torchsparse import SparseTensor
+from torchsparse.nn import functional as spF
+from torchsparse.utils.helpers import make_tuple
+from torchsparse.utils.kernel import KernelRegion, KernelMapKey
+
+from typing import Union, List, Tuple, Optional
 
 __all__ = ['conv3d']
 
@@ -70,8 +70,15 @@ def backward(ctx, grad_out):
         features, kernel, neighbor_map, neighbor_offset, transpose = ctx.for_backwards
         K, c_in, c_out = kernel.size()
         N_in = features.size(0)
-        grad_features = torch.zeros(N_in, c_in, device=features.device, dtype=features.dtype)
-        grad_kernel = torch.zeros(K, c_in, c_out, device=kernel.device, dtype=features.dtype)
+        grad_features = torch.zeros(N_in,
+                                    c_in,
+                                    device=features.device,
+                                    dtype=features.dtype)
+        grad_kernel = torch.zeros(K,
+                                  c_in,
+                                  c_out,
+                                  device=kernel.device,
+                                  dtype=features.dtype)
 
         if 'cuda' in str(features.device):
             torchsparse_backend.sparseconv_backward(features, grad_features,
@@ -87,18 +94,24 @@ def backward(ctx, grad_out):
 sparseconv_op = SpConvolution.apply
 
 
-def conv3d(inputs,
-           kernel,
-           kernel_size,
-           bias=None,
-           stride=1,
-           dilation=1,
-           transpose=False):
+def conv3d(inputs: SparseTensor,
+           kernel: torch.Tensor,
+           kernel_size: Union[int, List[int], Tuple[int, int, int]],
+           bias: Optional[torch.Tensor] = None,
+           stride: Union[int, List[int], Tuple[int, int, int]] = 1,
+           dilation: Union[int, List[int], Tuple[int, int, int]] = 1,
+           transpose: bool = False) -> SparseTensor:
     features = inputs.F
     coords = inputs.C
     cur_stride = inputs.s
 
-    if kernel_size == 1 and stride == 1 and dilation == 1:
+    # convert to hashable types
+    kernel_size = make_tuple(kernel_size)
+    stride = make_tuple(stride)
+    dilation = make_tuple(dilation)
+
+    if kernel_size == (1, 1, 1) and stride == (1, 1, 1) and dilation == (1, 1,
+                                                                         1):
         output_features = features.matmul(kernel)
         if bias is not None:
             output_features += bias
@@ -107,34 +120,37 @@ def conv3d(inputs,
         output_tensor.kernel_maps = inputs.kernel_maps
         output_tensor.check()
     elif not transpose:
-        kernel_map = inputs.kernel_maps.get(
-            'k%s_os%d_s%d_d%d' % (kernel_size, cur_stride, stride, dilation),
-            None)
+        kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
+                                      dilation)
+        kernel_map = inputs.kernel_maps.get(kernel_map_key, None)
 
-        if stride > 1:
+        if any(x > 1 for x in stride):
             # do downsample
             kRegion = KernelRegion(kernel_size=kernel_size,
                                    tensor_stride=cur_stride)
             kOffset = kRegion.get_kernel_offset().to(features.device)
-            new_coords = spdownsample(coords, stride * cur_stride)
-            hash_query = sphash(new_coords, kOffset)
-            hash_target = sphash(coords)
-            idx_query = sphashquery(hash_query, hash_target)
-            idx_query = list(convert_neighbor_map_gpu(idx_query))
+            new_coords = spF.spdownsample(coords, stride, kernel_size,
+                                          cur_stride)
+            hash_query = spF.sphash(new_coords, kOffset)
+            hash_target = spF.sphash(coords)
+            idx_query = spF.sphashquery(hash_query, hash_target)
+            idx_query = list(spF.squeeze_nmap(idx_query))
             idx_query[1] = idx_query[1].to('cpu')
             sizes = (features.shape[0], new_coords.shape[0])
             output_features = sparseconv_op(features, kernel, idx_query[0],
                                             idx_query[1], sizes, transpose)
             if bias is not None:
                 output_features += bias
-            output_tensor = SparseTensor(output_features, new_coords,
-                                         cur_stride * stride)
+            output_tensor = SparseTensor(
+                output_features, new_coords,
+                [a * b for a, b in zip(cur_stride, stride)])
             output_tensor.coord_maps = copy.deepcopy(inputs.coord_maps)
             output_tensor.check()
+
+            kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
+                                          dilation)
             output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
-            output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
-                                      (kernel_size, cur_stride, stride,
-                                       dilation)] = idx_query + [sizes]
+            output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]
 
         else:
             if kernel_map is None:
@@ -144,10 +160,10 @@ def conv3d(inputs,
                     kOffset = kRegion.get_kernel_offset().to(features.device)
                 except:
                     raise
-                hash_query = sphash(coords, kOffset)
-                hash_target = sphash(coords)
-                idx_query = sphashquery(hash_query, hash_target)
-                idx_query = list(convert_neighbor_map_gpu(idx_query))
+                hash_query = spF.sphash(coords, kOffset)
+                hash_target = spF.sphash(coords)
+                idx_query = spF.sphashquery(hash_query, hash_target)
+                idx_query = list(spF.squeeze_nmap(idx_query))
                 idx_query[1] = idx_query[1].to('cpu')
                 sizes = (features.shape[0], features.shape[0])
                 output_features = sparseconv_op(features, kernel, idx_query[0],
@@ -159,9 +175,9 @@ def conv3d(inputs,
                 output_tensor.coord_maps = inputs.coord_maps
                 output_tensor.check()
                 output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
-                output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
-                                          (kernel_size, cur_stride, stride,
-                                           dilation)] = idx_query + [sizes]
+                kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
+                                              dilation)
+                output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]
             else:
                 output_features = sparseconv_op(features, kernel,
                                                 kernel_map[0], kernel_map[1],
@@ -176,17 +192,24 @@ def conv3d(inputs,
 
     else:
         # do upsample
-        original_stride = int(cur_stride / stride)
-        kernel_map = inputs.kernel_maps.get(
-            'k%s_os%d_s%d_d%d' %
-            (kernel_size, original_stride, stride, dilation), None)
+
+        original_stride = tuple(
+            [int(a / b) for a, b in zip(cur_stride, stride)])
+
+        kernel_map_key = KernelMapKey(kernel_size, original_stride, stride,
+                                      dilation)
+        kernel_map = inputs.kernel_maps.get(kernel_map_key, None)
+        assert kernel_map is not None, f'{kernel_map_key} does not exist.'
         output_features = sparseconv_op(features, kernel, kernel_map[0],
                                         kernel_map[1], kernel_map[2],
                                         transpose)
         if bias is not None:
             output_features += bias
-        output_tensor = SparseTensor(output_features,
-                                     inputs.coord_maps[original_stride],
+
+        cur_coords = inputs.coord_maps.get(original_stride, None)
+        assert cur_coords is not None, f'{original_stride} not in coord maps.'
+
+        output_tensor = SparseTensor(output_features, cur_coords,
                                      original_stride)
         output_tensor.coord_maps = inputs.coord_maps
         output_tensor.check()

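For orientation, here is a minimal usage sketch of the generalized conv3d interface introduced above. It is not part of the commit: the point count, coordinate range, feature widths, and random kernel are made up, and it assumes the SparseTensor stride bookkeeping updated elsewhere in this change. The [kernel_volume, c_in, c_out] kernel layout follows the K, c_in, c_out = kernel.size() unpacking in the backward pass above.

import torch
from torchsparse import SparseTensor
from torchsparse.nn import functional as spF

# Hypothetical point cloud: 1000 points with integer (x, y, z, batch) coords
# and 16 input features per point.
coords = torch.randint(0, 64, (1000, 4), dtype=torch.int32).cuda()
coords[:, 3] = 0  # single sample in the batch
feats = torch.randn(1000, 16).cuda()
inputs = SparseTensor(feats, coords)

# Generalized (non-cubic) kernel 1 x 3 x 3 with anisotropic stride (1, 2, 2).
# The kernel tensor is laid out as [kernel_volume, c_in, c_out] = [9, 16, 32].
kernel = torch.randn(9, 16, 32).cuda()
outputs = spF.conv3d(inputs, kernel, kernel_size=(1, 3, 3), stride=(1, 2, 2))
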
torchsparse/nn/functional/convert_neighbor_map.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

torchsparse/nn/functional/downsample.py

Lines changed: 59 additions & 38 deletions
@@ -2,48 +2,69 @@
 import torchsparse_backend
 from torch.autograd import Function
 from torchsparse.nn.functional.hash import *
-from torchsparse.nn.functional.voxelize import spvoxelize
+from torchsparse.utils.kernel import KernelRegion
+from typing import Tuple, List, Union
 
 __all__ = ['spdownsample']
 
 
-class DownsampleGPU(Function):
-    @staticmethod
-    def forward(ctx, coords, ratio):
-        coords_float = coords[:, :3].float()
-        # following Minkowski engine
-        coords_new = torch.floor(torch.floor(coords_float / ratio) *
-                                 ratio).int()
-        coords_new = torch.cat([coords_new, coords[:, 3].view(-1, 1)], 1)
-        coords_new_hash = sphash(coords_new)
-        uq, inv, cnt = torch.unique(coords_new_hash,
-                                    return_inverse=True,
-                                    return_counts=True)
-        inv = inv.int()
-        cnt = cnt.int()
-        # rounding is necessary
-        # gpu
-        if 'cuda' in str(coords.device):
-            uq_coords = torch.round(spvoxelize(coords_new.float(), inv,
-                                               cnt))
-        elif 'cpu' in str(coords.device):
-            uq_coords = torch.round(
-                torchsparse_backend.cpu_insertion_forward(
-                    coords_new.float(), inv, cnt))
-        else:
-            device = coords.device
-            uq_coords = torch.round(
-                torchsparse_backend.cpu_insertion_forward(
-                    coords_new.float().cpu(), inv.cpu(), cnt.cpu()))
-            uq_coords = uq_coords.to(device)
-        uq_coords = uq_coords.int()
-
-        # Notice: coords_new_hash cannot be directly used
-        return uq_coords  #, coords_new_hash
-
+def spdownsample(
+        coords: torch.Tensor,
+        ratio: Union[int, List[int], Tuple[int, int, int]] = 2,
+        kernel_size: Union[int, List[int], Tuple[int, int, int]] = 2,
+        tensor_stride: Union[int, List[int], Tuple[int, int, int]] = 1
+) -> torch.Tensor:
 
-downsample_gpu = DownsampleGPU.apply
+    if not isinstance(ratio, int):
+        ratio = torch.IntTensor(ratio).to(coords.device).unsqueeze(0)
+    if not isinstance(tensor_stride, int):
+        tensor_stride = torch.IntTensor(tensor_stride).to(
+            coords.device).unsqueeze(0)
 
+    if isinstance(kernel_size, int) and isinstance(ratio, int):
+        direct_downsample = kernel_size == ratio
+    else:
+        if isinstance(kernel_size, int):
+            # ratio is a permutation of [1, 1, kernel_size]
+            direct_downsample = (kernel_size == ratio.prod().item()) & \
+                (torch.sum(ratio == kernel_size) == 1).item()
+        else:
+            direct_downsample = False
 
-def spdownsample(coords, ratio):
-    return downsample_gpu(coords, ratio)
+    if direct_downsample:
+        _ratio = ratio * tensor_stride
+        new_coords = torch.cat(
+            [coords[:, :3] // _ratio * _ratio, coords[:, 3:]], 1)
+        return torch.unique(new_coords, dim=0)
+    else:
+        kernel_region = KernelRegion(kernel_size, tensor_stride, dilation=1)
+        # kernel volume x 3
+        kernel_offset = kernel_region.get_kernel_offset().to(coords.device)
+        new_coords = coords[:, :3].unsqueeze(1).repeat(
+            1, kernel_offset.size(0), 1) + kernel_offset
+        # (N x kernel volume) x 4
+        new_coords = torch.cat([
+            coords[:, 3:].repeat(1, kernel_offset.size(0)).view(-1, 1),
+            new_coords.view(-1, 3)
+        ],
+                               dim=1)
+        new_ts = tensor_stride * ratio
+        # only keep coordinates that are multiples of new_ts
+        if isinstance(new_ts, torch.Tensor):
+            new_ts = new_ts[0]
+            new_coords = new_coords[
+                (new_coords[:, 1] % new_ts[0].item() == 0) &
+                (new_coords[:, 2] % new_ts[1].item() == 0) &
+                (new_coords[:, 3] % new_ts[2].item() == 0)]
+        else:
+            new_coords = new_coords[
+                (new_coords[:, 1] % new_ts == 0) &
+                (new_coords[:, 2] % new_ts == 0) &
+                (new_coords[:, 3] % new_ts == 0)]
+        new_coords = new_coords[(new_coords[:, 1] >= 0)
+                                & (new_coords[:, 2] >= 0) &
+                                (new_coords[:, 3] >= 0)]
+        # filter out duplicates
+        new_coords = torch.unique(new_coords, dim=0)
+        new_coords = new_coords[:, [1, 2, 3, 0]]
+        return new_coords
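
To make the two code paths above concrete, here is a small sketch (not from the commit; the coordinates are made up, and the tuple arguments mirror how conv3d invokes this function for anisotropic strides):

import torch
from torchsparse.nn import functional as spF

# Integer (x, y, z, batch) coordinates at tensor stride 1.
coords = torch.IntTensor([[0, 0, 0, 0],
                          [1, 0, 2, 0],
                          [3, 1, 1, 0]])

# Direct branch: kernel_size == ratio (both 2), so spatial coordinates are
# floored to multiples of 2 and deduplicated.
down_cubic = spF.spdownsample(coords, ratio=2, kernel_size=2, tensor_stride=1)

# Generalized branch: anisotropic ratio (1, 2, 2) with a (1, 3, 3) kernel
# enumerates the kernel offsets and keeps coordinates divisible by the new
# tensor stride.
down_aniso = spF.spdownsample(coords, ratio=(1, 2, 2), kernel_size=(1, 3, 3),
                              tensor_stride=(1, 1, 1))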

torchsparse/nn/functional/squeeze_nmap.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import torch
+
+__all__ = ['squeeze_nmap']
+
+
+def squeeze_nmap(neighbor_map: torch.Tensor) -> torch.Tensor:
+    idx_batch, idx_point = torch.where(neighbor_map != -1)
+    map_converted = neighbor_map.view(-1)[idx_batch * neighbor_map.size(1) +
+                                          idx_point]
+    map_converted = torch.stack([map_converted, idx_point], dim=1)
+    nmap_offset = torch.sum(neighbor_map != -1, 1)
+    return map_converted.int().contiguous(), nmap_offset.int().contiguous()

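To illustrate what the new helper computes (made-up numbers; rows are interpreted as kernel offsets here): for a dense neighbor map in which -1 marks an empty slot, squeeze_nmap packs the valid entries into (input index, query index) pairs plus a per-row count, i.e. the neighbor_map / neighbor_offset pair that conv.py feeds to sparseconv_op.

import torch
from torchsparse.nn import functional as spF

# One row per kernel offset; -1 means "no matching input point".
neighbor_map = torch.tensor([[ 3, -1,  5],
                             [-1,  2, -1]])

map_converted, nmap_offset = spF.squeeze_nmap(neighbor_map)
# map_converted -> [[3, 0], [5, 2], [2, 1]]  (input index, query index)
# nmap_offset   -> [2, 1]                    (valid pairs per row)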