Skip to content

Commit a40e3a3

Browse files
LaserBitrohitgr7awaelchli
authored
Change the classifier input from 2048 to 1000. (#5232)
* Change the classifier input from 2048 to 1000. * Update docs for Imagenet example Thanks @rohitgr7 * Apply suggestions from code review Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent d5b3678 commit a40e3a3

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

docs/source/transfer_learning.rst

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,22 @@ Example: Imagenet (computer Vision)
5252

5353
class ImagenetTransferLearning(LightningModule):
5454
def __init__(self):
55+
super().__init__()
56+
5557
# init a pretrained resnet
56-
num_target_classes = 10
57-
self.feature_extractor = models.resnet50(pretrained=True)
58-
self.feature_extractor.eval()
58+
backbone = models.resnet50(pretrained=True)
59+
num_filters = backbone.fc.in_features
60+
layers = list(backbone.children())[:-1]
61+
self.feature_extractor = torch.nn.Sequential(*layers)
5962
6063
# use the pretrained model to classify cifar-10 (10 image classes)
61-
self.classifier = nn.Linear(2048, num_target_classes)
64+
num_target_classes = 10
65+
self.classifier = nn.Linear(num_filters, num_target_classes)
6266

6367
def forward(self, x):
64-
representations = self.feature_extractor(x)
68+
self.feature_extractor.eval()
69+
with torch.no_grad():
70+
representations = self.feature_extractor(x).flatten(1)
6571
x = self.classifier(representations)
6672
...
6773

0 commit comments

Comments
 (0)