44import torchsparse_backend
55from torch .autograd import Function
66from torch .cuda .amp import custom_fwd , custom_bwd
7- from torchsparse import *
8- from torchsparse .nn . functional . convert_neighbor_map import *
9- from torchsparse .nn . functional . downsample import *
10- from torchsparse .nn . functional . hash import *
11- from torchsparse . nn . functional . query import *
12- from torchsparse . utils . kernel_region import *
7+ from torchsparse import SparseTensor
8+ from torchsparse .nn import functional as spF
9+ from torchsparse .utils . helpers import make_tuple
10+ from torchsparse .utils . kernel import KernelRegion , KernelMapKey
11+
12+ from typing import Union , List , Tuple , Optional
1313
1414__all__ = ['conv3d' ]
1515
@@ -70,8 +70,15 @@ def backward(ctx, grad_out):
7070 features , kernel , neighbor_map , neighbor_offset , transpose = ctx .for_backwards
7171 K , c_in , c_out = kernel .size ()
7272 N_in = features .size (0 )
73- grad_features = torch .zeros (N_in , c_in , device = features .device , dtype = features .dtype )
74- grad_kernel = torch .zeros (K , c_in , c_out , device = kernel .device , dtype = features .dtype )
73+ grad_features = torch .zeros (N_in ,
74+ c_in ,
75+ device = features .device ,
76+ dtype = features .dtype )
77+ grad_kernel = torch .zeros (K ,
78+ c_in ,
79+ c_out ,
80+ device = kernel .device ,
81+ dtype = features .dtype )
7582
7683 if 'cuda' in str (features .device ):
7784 torchsparse_backend .sparseconv_backward (features , grad_features ,
@@ -87,18 +94,24 @@ def backward(ctx, grad_out):
8794sparseconv_op = SpConvolution .apply
8895
8996
90- def conv3d (inputs ,
91- kernel ,
92- kernel_size ,
93- bias = None ,
94- stride = 1 ,
95- dilation = 1 ,
96- transpose = False ):
97+ def conv3d (inputs : SparseTensor ,
98+ kernel : torch . Tensor ,
99+ kernel_size : Union [ int , List [ int ], Tuple [ int , int , int ]] ,
100+ bias : Optional [ torch . Tensor ] = None ,
101+ stride : Union [ int , List [ int ], Tuple [ int , int , int ]] = 1 ,
102+ dilation : Union [ int , List [ int ], Tuple [ int , int , int ]] = 1 ,
103+ transpose : bool = False ) -> SparseTensor :
97104 features = inputs .F
98105 coords = inputs .C
99106 cur_stride = inputs .s
100107
101- if kernel_size == 1 and stride == 1 and dilation == 1 :
108+ # convert to hashable types
109+ kernel_size = make_tuple (kernel_size )
110+ stride = make_tuple (stride )
111+ dilation = make_tuple (dilation )
112+
113+ if kernel_size == (1 , 1 , 1 ) and stride == (1 , 1 , 1 ) and dilation == (1 , 1 ,
114+ 1 ):
102115 output_features = features .matmul (kernel )
103116 if bias is not None :
104117 output_features += bias
@@ -107,34 +120,37 @@ def conv3d(inputs,
107120 output_tensor .kernel_maps = inputs .kernel_maps
108121 output_tensor .check ()
109122 elif not transpose :
110- kernel_map = inputs . kernel_maps . get (
111- 'k%s_os%d_s%d_d%d' % ( kernel_size , cur_stride , stride , dilation ),
112- None )
123+ kernel_map_key = KernelMapKey ( kernel_size , cur_stride , stride ,
124+ dilation )
125+ kernel_map = inputs . kernel_maps . get ( kernel_map_key , None )
113126
114- if stride > 1 :
127+ if any ( x > 1 for x in stride ) :
115128 # do downsample
116129 kRegion = KernelRegion (kernel_size = kernel_size ,
117130 tensor_stride = cur_stride )
118131 kOffset = kRegion .get_kernel_offset ().to (features .device )
119- new_coords = spdownsample (coords , stride * cur_stride )
120- hash_query = sphash (new_coords , kOffset )
121- hash_target = sphash (coords )
122- idx_query = sphashquery (hash_query , hash_target )
123- idx_query = list (convert_neighbor_map_gpu (idx_query ))
132+ new_coords = spF .spdownsample (coords , stride , kernel_size ,
133+ cur_stride )
134+ hash_query = spF .sphash (new_coords , kOffset )
135+ hash_target = spF .sphash (coords )
136+ idx_query = spF .sphashquery (hash_query , hash_target )
137+ idx_query = list (spF .squeeze_nmap (idx_query ))
124138 idx_query [1 ] = idx_query [1 ].to ('cpu' )
125139 sizes = (features .shape [0 ], new_coords .shape [0 ])
126140 output_features = sparseconv_op (features , kernel , idx_query [0 ],
127141 idx_query [1 ], sizes , transpose )
128142 if bias is not None :
129143 output_features += bias
130- output_tensor = SparseTensor (output_features , new_coords ,
131- cur_stride * stride )
144+ output_tensor = SparseTensor (
145+ output_features , new_coords ,
146+ [a * b for a , b in zip (cur_stride , stride )])
132147 output_tensor .coord_maps = copy .deepcopy (inputs .coord_maps )
133148 output_tensor .check ()
149+
150+ kernel_map_key = KernelMapKey (kernel_size , cur_stride , stride ,
151+ dilation )
134152 output_tensor .kernel_maps = copy .deepcopy (inputs .kernel_maps )
135- output_tensor .kernel_maps ['k%s_os%d_s%d_d%d' %
136- (kernel_size , cur_stride , stride ,
137- dilation )] = idx_query + [sizes ]
153+ output_tensor .kernel_maps [kernel_map_key ] = idx_query + [sizes ]
138154
139155 else :
140156 if kernel_map is None :
@@ -144,10 +160,10 @@ def conv3d(inputs,
144160 kOffset = kRegion .get_kernel_offset ().to (features .device )
145161 except :
146162 raise
147- hash_query = sphash (coords , kOffset )
148- hash_target = sphash (coords )
149- idx_query = sphashquery (hash_query , hash_target )
150- idx_query = list (convert_neighbor_map_gpu (idx_query ))
163+ hash_query = spF . sphash (coords , kOffset )
164+ hash_target = spF . sphash (coords )
165+ idx_query = spF . sphashquery (hash_query , hash_target )
166+ idx_query = list (spF . squeeze_nmap (idx_query ))
151167 idx_query [1 ] = idx_query [1 ].to ('cpu' )
152168 sizes = (features .shape [0 ], features .shape [0 ])
153169 output_features = sparseconv_op (features , kernel , idx_query [0 ],
@@ -159,9 +175,9 @@ def conv3d(inputs,
159175 output_tensor .coord_maps = inputs .coord_maps
160176 output_tensor .check ()
161177 output_tensor .kernel_maps = copy .deepcopy (inputs .kernel_maps )
162- output_tensor . kernel_maps [ 'k%s_os%d_s%d_d%d' %
163- ( kernel_size , cur_stride , stride ,
164- dilation ) ] = idx_query + [sizes ]
178+ kernel_map_key = KernelMapKey ( kernel_size , cur_stride , stride ,
179+ dilation )
180+ output_tensor . kernel_maps [ kernel_map_key ] = idx_query + [sizes ]
165181 else :
166182 output_features = sparseconv_op (features , kernel ,
167183 kernel_map [0 ], kernel_map [1 ],
@@ -176,17 +192,24 @@ def conv3d(inputs,
176192
177193 else :
178194 # do upsample
179- original_stride = int (cur_stride / stride )
180- kernel_map = inputs .kernel_maps .get (
181- 'k%s_os%d_s%d_d%d' %
182- (kernel_size , original_stride , stride , dilation ), None )
195+
196+ original_stride = tuple (
197+ [int (a / b ) for a , b in zip (cur_stride , stride )])
198+
199+ kernel_map_key = KernelMapKey (kernel_size , original_stride , stride ,
200+ dilation )
201+ kernel_map = inputs .kernel_maps .get (kernel_map_key , None )
202+ assert kernel_map is not None , f'{ kernel_map_key } does not exist.'
183203 output_features = sparseconv_op (features , kernel , kernel_map [0 ],
184204 kernel_map [1 ], kernel_map [2 ],
185205 transpose )
186206 if bias is not None :
187207 output_features += bias
188- output_tensor = SparseTensor (output_features ,
189- inputs .coord_maps [original_stride ],
208+
209+ cur_coords = inputs .coord_maps .get (original_stride , None )
210+ assert cur_coords is not None , f'{ original_stride } not in coord maps.'
211+
212+ output_tensor = SparseTensor (output_features , cur_coords ,
190213 original_stride )
191214 output_tensor .coord_maps = inputs .coord_maps
192215 output_tensor .check ()
0 commit comments