Skip to content

Commit b3c83a1

Browse files
committed
Update rfnet.py
1 parent 8cbac0c commit b3c83a1

File tree

1 file changed

+27
-1
lines changed
  • examples/robot/lifelong_learning_bench/semantic-segmentation/testalgorithms/rfnet/RFNet/models

1 file changed

+27
-1
lines changed
Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,27 @@
1-
from logger import logger
1+
import torch.nn as nn
2+
from itertools import chain # 串联多个迭代对象
3+
from logger import logger
4+
from .util import _BNReluConv, upsample
5+
6+
7+
class RFNet(nn.Module):
8+
def __init__(self, backbone, num_classes, use_bn=True):
9+
super(RFNet, self).__init__()
10+
self.backbone = backbone
11+
self.num_classes = num_classes
12+
logger.info(self.backbone.num_features)
13+
self.logits = _BNReluConv(self.backbone.num_features, self.num_classes, batch_norm=use_bn)
14+
15+
def forward(self, rgb_inputs, depth_inputs = None):
16+
x, additional = self.backbone(rgb_inputs, depth_inputs)
17+
logits = self.logits.forward(x)
18+
logits_upsample = upsample(logits, rgb_inputs.shape[2:])
19+
#print(logits_upsample.size)
20+
return logits_upsample
21+
22+
23+
def random_init_params(self):
24+
return chain(*([self.logits.parameters(), self.backbone.random_init_params()]))
25+
26+
def fine_tune_params(self):
27+
return self.backbone.fine_tune_params()

0 commit comments

Comments
 (0)