import paddle
import paddle.nn as nn

from ppsci.arch import base


class ResNetBlock(nn.Layer):
    """Basic residual block: two 3x3 conv + batch-norm layers with ReLU, plus a
    1x1 projection shortcut when the stride or channel count changes."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2D(in_channels, out_channels, 3, stride, 1)
        self.bn1 = nn.BatchNorm2D(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2D(out_channels)
        if stride != 1 or in_channels != out_channels:
            # project the identity so its shape matches the residual branch
            self.downsample = nn.Sequential(
                nn.Conv2D(in_channels, out_channels, 1, stride),
                nn.BatchNorm2D(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


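# A minimal shape sketch for ResNetBlock, assuming a 4D NCHW input tensor:
# a stride-2 block halves the spatial size while the 1x1 projection shortcut
# matches the channel count, e.g.
#   block = ResNetBlock(64, 128, stride=2)
#   y = block(paddle.randn([4, 64, 32, 32]))  # y.shape -> [4, 128, 16, 16]

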
class ResNet(base.Arch):
    """PaddleScience-style ResNet with configurable input/output keys, per-stage
    block counts, input channels, and base width. A usage sketch follows at the
    bottom of this module.
    """

    def __init__(
        self,
        input_keys,
        output_keys,
        num_blocks=(2, 2, 2, 2),  # ResNet-18 layout by default
        num_classes=1,
        in_channels=3,
        base_channels=64,
        **kwargs
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        # stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling
        self.conv1 = nn.Conv2D(in_channels, base_channels, 7, 2, 3)
        self.bn1 = nn.BatchNorm2D(base_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2D(3, 2, 1)

        # four residual stages; stages 2-4 halve the spatial size and double the width
        self.layer1 = self._make_layer(base_channels, base_channels, num_blocks[0])
        self.layer2 = self._make_layer(
            base_channels, base_channels * 2, num_blocks[1], stride=2
        )
        self.layer3 = self._make_layer(
            base_channels * 2, base_channels * 4, num_blocks[2], stride=2
        )
        self.layer4 = self._make_layer(
            base_channels * 4, base_channels * 8, num_blocks[3], stride=2
        )

        # classification head: global average pooling followed by a linear layer
        self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
        self.fc = nn.Linear(base_channels * 8, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        layers = [ResNetBlock(in_channels, out_channels, stride)]
        for _ in range(1, blocks):
            layers.append(ResNetBlock(out_channels, out_channels))
        return nn.Sequential(*layers)
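
    # For example, _make_layer(64, 128, 2, stride=2) composes (sketch):
    #   Sequential(ResNetBlock(64, 128, stride=2), ResNetBlock(128, 128))
    # i.e. only the first block of a stage changes the stride/channel count.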

    def forward(self, x):
        # the input may be a dict keyed by input_keys (the PaddleScience
        # convention) or a plain tensor
        is_dict_input = isinstance(x, dict)
        if is_dict_input:
            x = x[self.input_keys[0]]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = paddle.flatten(x, 1)
        x = self.fc(x)
        if is_dict_input:
            # mirror the dict input: return the result keyed by output_keys
            return {self.output_keys[0]: x}
        return x
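

# A minimal usage sketch; the key names, input shape, and class count below are
# illustrative assumptions, not part of the module's required interface.
if __name__ == "__main__":
    model = ResNet(
        input_keys=("image",),
        output_keys=("pred",),
        num_blocks=(2, 2, 2, 2),  # ResNet-18 layout
        num_classes=10,
        in_channels=3,
    )
    out = model({"image": paddle.randn([4, 3, 224, 224])})
    print(out["pred"].shape)  # expected: [4, 10]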