diff --git a/torchsparse/nn/functional/conv/conv.py b/torchsparse/nn/functional/conv/conv.py index b8d4e13..a4504b8 100644 --- a/torchsparse/nn/functional/conv/conv.py +++ b/torchsparse/nn/functional/conv/conv.py @@ -135,6 +135,33 @@ def conv3d( (tensor_stride, kernel_size, stride, dilation) ) + + hashmap_keys, hashmap_vals = None, None + if kmap is None: + kmap = F.build_kernel_map( + coords, + feats.shape[0], + kernel_size, + stride, + padding, + hashmap_keys, + hashmap_vals, + input.spatial_range, + kmap_mode, + dataflow, + downsample_mode=config.downsample_mode, + training=training, + ifsort=config.ifsort, + split_mask_num=config.split_mask_num, + split_mask_num_bwd=config.split_mask_num_bwd, + ) + + hashmap = [kmap["hashmap_keys"], kmap["hashmap_vals"]] + + input._caches.kmaps[(input.stride, kernel_size, stride, dilation)] = kmap + input._caches.hashmaps[input.stride] = hashmap + + kmap = F.transpose_kernel_map( kmap, config.ifsort, @@ -153,6 +180,11 @@ def conv3d( if bias is not None: feats += bias + input._caches.cmaps[tensor_stride] = ( + kmap["coords"], + kmap.get("spatial_range"), + ) + output = SparseTensor( coords=input._caches.cmaps[tensor_stride][0], feats=feats,