Skip to content

Commit b46f5ac

Browse files
authored
Don't reuse nn.ReLU modules (#239)
1 parent b4ae449 commit b46f5ac

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

clip/model.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@ def __init__(self, inplanes, planes, stride=1):
1616
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
1717
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
1818
self.bn1 = nn.BatchNorm2d(planes)
19+
self.relu1 = nn.ReLU(inplace=True)
1920

2021
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
2122
self.bn2 = nn.BatchNorm2d(planes)
23+
self.relu2 = nn.ReLU(inplace=True)
2224

2325
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
2426

2527
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
2628
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29+
self.relu3 = nn.ReLU(inplace=True)
2730

28-
self.relu = nn.ReLU(inplace=True)
2931
self.downsample = None
3032
self.stride = stride
3133

@@ -40,16 +42,16 @@ def __init__(self, inplanes, planes, stride=1):
4042
def forward(self, x: torch.Tensor):
4143
identity = x
4244

43-
out = self.relu(self.bn1(self.conv1(x)))
44-
out = self.relu(self.bn2(self.conv2(out)))
45+
out = self.relu1(self.bn1(self.conv1(x)))
46+
out = self.relu2(self.bn2(self.conv2(out)))
4547
out = self.avgpool(out)
4648
out = self.bn3(self.conv3(out))
4749

4850
if self.downsample is not None:
4951
identity = self.downsample(x)
5052

5153
out += identity
52-
out = self.relu(out)
54+
out = self.relu3(out)
5355
return out
5456

5557

@@ -106,12 +108,14 @@ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
106108
# the 3-layer stem
107109
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108110
self.bn1 = nn.BatchNorm2d(width // 2)
111+
self.relu1 = nn.ReLU(inplace=True)
109112
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110113
self.bn2 = nn.BatchNorm2d(width // 2)
114+
self.relu2 = nn.ReLU(inplace=True)
111115
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112116
self.bn3 = nn.BatchNorm2d(width)
117+
self.relu3 = nn.ReLU(inplace=True)
113118
self.avgpool = nn.AvgPool2d(2)
114-
self.relu = nn.ReLU(inplace=True)
115119

116120
# residual layers
117121
self._inplanes = width # this is a *mutable* variable used during construction
@@ -134,8 +138,9 @@ def _make_layer(self, planes, blocks, stride=1):
134138

135139
def forward(self, x):
136140
def stem(x):
137-
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138-
x = self.relu(bn(conv(x)))
141+
x = self.relu1(self.bn1(self.conv1(x)))
142+
x = self.relu2(self.bn2(self.conv2(x)))
143+
x = self.relu3(self.bn3(self.conv3(x)))
139144
x = self.avgpool(x)
140145
return x
141146

0 commit comments

Comments
 (0)