Skip to content

Commit 3806340

Browse files
committed
Using torch.nn.ReLU instead of fvdb.nn.ReLU; some import cleanups
Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
1 parent d84276b commit 3806340

File tree

1 file changed

+7
-17
lines changed

1 file changed

+7
-17
lines changed

fvdb/nn/simple_unet.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,13 @@
3030
such as 3D semantic segmentation, shape completion, or volumetric reconstruction.
3131
"""
3232

33-
import math
34-
from typing import Any, Sequence
3533

3634
import fvdb.nn as fvnn
3735
import torch
3836
import torch.nn as nn
39-
from fvdb.types import (
40-
NumericMaxRank1,
41-
NumericMaxRank2,
42-
ValueConstraint,
43-
to_Vec3i,
44-
to_Vec3iBroadcastable,
45-
)
46-
from torch.profiler import record_function
47-
48-
import fvdb
49-
from fvdb import ConvolutionPlan, Grid, GridBatch, JaggedTensor
37+
from fvdb.types import NumericMaxRank1
38+
39+
from fvdb import ConvolutionPlan, GridBatch, JaggedTensor
5040

5141
from .modules import fvnn_module
5242

@@ -87,7 +77,7 @@ def __init__(
8777

8878
self.conv = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=kernel_size, stride=1, bias=False)
8979
self.batch_norm = fvnn.BatchNorm(out_channels, momentum=momentum)
90-
self.relu = fvnn.ReLU(inplace=True)
80+
self.relu = torch.nn.ReLU(inplace=True)
9181

9282
def extra_repr(self) -> str:
9383
return (
@@ -107,7 +97,7 @@ def forward(
10797
x = self.conv(data, plan)
10898
out_grid = plan.target_grid_batch
10999
x = self.batch_norm(x, out_grid)
110-
x = self.relu(x, out_grid)
100+
x = self.relu(x)
111101
return x
112102

113103

@@ -170,7 +160,7 @@ def __init__(
170160

171161
self.blocks = nn.ModuleList(layers)
172162

173-
self.final_relu = fvnn.ReLU(inplace=True)
163+
self.final_relu = torch.nn.ReLU(inplace=True)
174164

175165
def extra_repr(self) -> str:
176166
return (
@@ -199,7 +189,7 @@ def forward(
199189
data = block(data, plan)
200190

201191
data = data + residual
202-
data = self.final_relu(data, plan.target_grid_batch)
192+
data = self.final_relu(data)
203193

204194
return data
205195

0 commit comments

Comments
 (0)