diff --git a/models/cswin.py b/models/cswin.py
index 9a74de1..ae1b2e7 100644
--- a/models/cswin.py
+++ b/models/cswin.py
@@ -326,16 +326,23 @@ def no_weight_decay(self):
 
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=''):
-        if self.num_classes != num_classes:
-            print ('reset head to', num_classes)
+    def reset_classifier(self, num_classes, force=False, to_gpu=False):
+        if self.num_classes != num_classes or force:
+            print("reset head to", num_classes)
             self.num_classes = num_classes
-            self.head = nn.Linear(self.out_dim, num_classes) if num_classes > 0 else nn.Identity()
-            self.head = self.head.cuda()
-            trunc_normal_(self.head.weight, std=.02)
+            self.head = (
+                nn.Linear(self.head.in_features, num_classes)
+                if num_classes > 0
+                else nn.Identity()
+            )
+            if to_gpu:
+                self.head = self.head.cuda()
+            # init new head
+            trunc_normal_(self.head.weight, std=0.02)
             if self.head.bias is not None:
                 nn.init.constant_(self.head.bias, 0)
+
     def forward_features(self, x):
         B = x.shape[0]
         x = self.stage1_conv_embed(x)
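
For reference, here is a minimal, self-contained sketch of how the new `reset_classifier(num_classes, force=False, to_gpu=False)` signature behaves. `ToyModel` is a hypothetical stand-in with only an `nn.Linear` head, not the real CSWin class; it just mirrors the head-reset logic from the diff above (it assumes `trunc_normal_` from timm, the same helper `models/cswin.py` imports, and that `num_classes > 0` so the head has weights to initialise).

```python
# Hedged sketch: ToyModel is a stand-in, NOT the real CSWin class.
# It only reproduces the head-reset logic shown in the diff above.
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_  # cswin.py imports this from timm


class ToyModel(nn.Module):
    def __init__(self, embed_dim=64, num_classes=1000):
        super().__init__()
        self.num_classes = num_classes
        self.head = nn.Linear(embed_dim, num_classes)

    def reset_classifier(self, num_classes, force=False, to_gpu=False):
        # Rebuild the head only when the class count changes, unless force=True.
        if self.num_classes != num_classes or force:
            self.num_classes = num_classes
            self.head = (
                nn.Linear(self.head.in_features, num_classes)
                if num_classes > 0
                else nn.Identity()
            )
            if to_gpu:
                self.head = self.head.cuda()
            # Re-initialise the new head (assumes num_classes > 0, since an
            # nn.Identity() head has no weight or bias to initialise).
            trunc_normal_(self.head.weight, std=0.02)
            if self.head.bias is not None:
                nn.init.constant_(self.head.bias, 0)


model = ToyModel()
model.reset_classifier(10)               # rebuilds: 1000 -> 10 classes
model.reset_classifier(10)               # no-op: same size and force=False
model.reset_classifier(10, force=True)   # rebuilds and re-initialises anyway
print(model.head)                        # Linear(in_features=64, out_features=10, bias=True)
```

Note the default behaviour stays CPU-side: unlike the removed code, the new head is only moved with `.cuda()` when `to_gpu=True` is passed explicitly.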