|
7 | 7 |
|
8 | 8 | from __future__ import print_function |
9 | 9 |
|
| 10 | +from torch.hub import load_state_dict_from_url |
10 | 11 |
|
11 | | -def deeplabv2_resnet101(pretrained=False, **kwargs): |
12 | | - """ |
13 | | - DeepLab v2 model with ResNet-101 backbone |
14 | | - n_classes (int): the number of classes |
15 | | - """ |
| 12 | +model_url_root = "https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/" |
| 13 | +model_dict = { |
| 14 | + "cocostuff10k": ("deeplabv2_resnet101_msc-cocostuff10k-20000.pth", 182), |
| 15 | + "cocostuff164k": ("deeplabv2_resnet101_msc-cocostuff164k-100000.pth", 182), |
| 16 | + "voc12": ("deeplabv2_resnet101_msc-vocaug-20000.pth", 21), |
| 17 | +} |
16 | 18 |
|
17 | | - if pretrained: |
18 | | - raise NotImplementedError( |
19 | | - "Please download from " |
20 | | - "https://github.com/kazuto1011/deeplab-pytorch/tree/master#performance" |
21 | | - ) |
| 19 | + |
| 20 | +def deeplabv2_resnet101(pretrained=None, n_classes=182, scales=None): |
22 | 21 |
|
23 | 22 | from libs.models.deeplabv2 import DeepLabV2 |
24 | 23 | from libs.models.msc import MSC |
25 | 24 |
|
26 | | - base = DeepLabV2(n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24], **kwargs) |
27 | | - model = MSC(base=base, scales=[0.5, 0.75]) |
| 25 | + # Model parameters |
| 26 | + n_blocks = [3, 4, 23, 3] |
| 27 | + atrous_rates = [6, 12, 18, 24] |
| 28 | + if scales is None: |
| 29 | + scales = [0.5, 0.75] |
28 | 30 |
|
29 | | - return model |
| 31 | + base = DeepLabV2(n_classes=n_classes, n_blocks=n_blocks, atrous_rates=atrous_rates) |
| 32 | + model = MSC(base=base, scales=scales) |
30 | 33 |
|
| 34 | + # Load pretrained models |
| 35 | + if isinstance(pretrained, str): |
31 | 36 |
|
32 | | -if __name__ == "__main__": |
33 | | - import torch.hub |
| 37 | + assert pretrained in model_dict, list(model_dict.keys()) |
| 38 | + expected = model_dict[pretrained][1] |
| 39 | + error_message = "Expected: n_classes={}".format(expected) |
| 40 | + assert n_classes == expected, error_message |
34 | 41 |
|
35 | | - model = torch.hub.load( |
36 | | - "kazuto1011/deeplab-pytorch", |
37 | | - "deeplabv2_resnet101", |
38 | | - n_classes=182, |
39 | | - force_reload=True, |
40 | | - ) |
| 42 | + model_url = model_url_root + model_dict[pretrained][0] |
| 43 | + state_dict = load_state_dict_from_url(model_url) |
| 44 | + model.load_state_dict(state_dict) |
| 45 | + |
| 46 | + return model |
41 | 47 |
|
42 | | - print(model) |
|
0 commit comments