Skip to content

Commit 081fc4a

Browse files
Add option to do sum pooling in the CNN encoder
1 parent b018cbe commit 081fc4a

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

spine/model/layer/cnn/encoder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ def __init__(self, coord_conv=False, pool_mode='avg',
4646

4747
if pool_mode == 'avg':
4848
# Average pooling
49-
self.pool = ME.MinkowskiGlobalPooling()
49+
self.pool = ME.MinkowskiGlobalAvgPooling()
50+
51+
if pool_mode == 'sum':
52+
# Sum pooling
53+
self.pool = ME.MinkowskiGlobalSumPooling()
5054

5155
elif pool_mode == 'max':
5256
# Max pooling
53-
self.pool = torch.nn.Sequential(
54-
ME.MinkowskiMaxPooling(
55-
final_tensor_shape, stride=final_tensor_shape,
56-
dimension=self.dim),
57-
ME.MinkowskiGlobalPooling())
57+
self.pool = ME.MinkowskiGlobalMaxPooling()
5858

5959
elif pool_mode == 'conv':
6060
# Strided convolution
@@ -69,7 +69,7 @@ def __init__(self, coord_conv=False, pool_mode='avg',
6969
else:
7070
raise ValueError(
7171
f"Pooling mode not recognized: {self.pool_mode}. Must be "
72-
"one of 'avg', 'max' or 'conv'")
72+
"one of 'avg', 'sum', 'max' or 'conv'")
7373

7474
# Initialize the final linear layer
7575
self.feature_size = feature_size

0 commit comments

Comments
 (0)