diff --git a/lib/models/seg_hrnet_ocr.py b/lib/models/seg_hrnet_ocr.py index ed9df629..8ab88b44 100644 --- a/lib/models/seg_hrnet_ocr.py +++ b/lib/models/seg_hrnet_ocr.py @@ -62,7 +62,7 @@ def forward(self, feats, probs): feats = feats.permute(0, 2, 1) # batch x hw x c probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw ocr_context = torch.matmul(probs, feats)\ - .permute(0, 2, 1).unsqueeze(3)# batch x k x c + .permute(0, 2, 1).unsqueeze(3) # batch x c x k x 1 return ocr_context