3030such as 3D semantic segmentation, shape completion, or volumetric reconstruction.
3131"""
3232
33- import math
34- from typing import Any , Sequence
3533
3634import fvdb .nn as fvnn
3735import torch
3836import torch .nn as nn
39- from fvdb .types import (
40- NumericMaxRank1 ,
41- NumericMaxRank2 ,
42- ValueConstraint ,
43- to_Vec3i ,
44- to_Vec3iBroadcastable ,
45- )
46- from torch .profiler import record_function
47-
48- import fvdb
49- from fvdb import ConvolutionPlan , Grid , GridBatch , JaggedTensor
37+ from fvdb .types import NumericMaxRank1
38+
39+ from fvdb import ConvolutionPlan , GridBatch , JaggedTensor
5040
5141from .modules import fvnn_module
5242
@@ -87,7 +77,7 @@ def __init__(
8777
8878 self .conv = fvnn .SparseConv3d (in_channels , out_channels , kernel_size = kernel_size , stride = 1 , bias = False )
8979 self .batch_norm = fvnn .BatchNorm (out_channels , momentum = momentum )
90- self .relu = fvnn .ReLU (inplace = True )
80+ self .relu = torch . nn .ReLU (inplace = True )
9181
9282 def extra_repr (self ) -> str :
9383 return (
@@ -107,7 +97,7 @@ def forward(
10797 x = self .conv (data , plan )
10898 out_grid = plan .target_grid_batch
10999 x = self .batch_norm (x , out_grid )
110- x = self .relu (x , out_grid )
100+ x = self .relu (x )
111101 return x
112102
113103
@@ -170,7 +160,7 @@ def __init__(
170160
171161 self .blocks = nn .ModuleList (layers )
172162
173- self .final_relu = fvnn .ReLU (inplace = True )
163+ self .final_relu = torch . nn .ReLU (inplace = True )
174164
175165 def extra_repr (self ) -> str :
176166 return (
@@ -199,7 +189,7 @@ def forward(
199189 data = block (data , plan )
200190
201191 data = data + residual
202- data = self .final_relu (data , plan . target_grid_batch )
192+ data = self .final_relu (data )
203193
204194 return data
205195
0 commit comments