@@ -16,16 +16,18 @@ def __init__(self, inplanes, planes, stride=1):
         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
         self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
 
         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
         self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = nn.ReLU(inplace=True)
 
         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
 
         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu3 = nn.ReLU(inplace=True)
 
-        self.relu = nn.ReLU(inplace=True)
         self.downsample = None
         self.stride = stride
 
@@ -40,16 +42,16 @@ def __init__(self, inplanes, planes, stride=1):
     def forward(self, x: torch.Tensor):
         identity = x
 
-        out = self.relu(self.bn1(self.conv1(x)))
-        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.relu2(self.bn2(self.conv2(out)))
         out = self.avgpool(out)
         out = self.bn3(self.conv3(out))
 
         if self.downsample is not None:
             identity = self.downsample(x)
 
         out += identity
-        out = self.relu(out)
+        out = self.relu3(out)
         return out
 
 
@@ -106,12 +108,14 @@ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
         # the 3-layer stem
         self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
         self.bn1 = nn.BatchNorm2d(width // 2)
+        self.relu1 = nn.ReLU(inplace=True)
         self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
         self.bn2 = nn.BatchNorm2d(width // 2)
+        self.relu2 = nn.ReLU(inplace=True)
         self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
         self.bn3 = nn.BatchNorm2d(width)
+        self.relu3 = nn.ReLU(inplace=True)
         self.avgpool = nn.AvgPool2d(2)
-        self.relu = nn.ReLU(inplace=True)
 
         # residual layers
         self._inplanes = width  # this is a *mutable* variable used during construction
@@ -134,8 +138,9 @@ def _make_layer(self, planes, blocks, stride=1):
 
     def forward(self, x):
         def stem(x):
-            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
-                x = self.relu(bn(conv(x)))
+            x = self.relu1(self.bn1(self.conv1(x)))
+            x = self.relu2(self.bn2(self.conv2(x)))
+            x = self.relu3(self.bn3(self.conv3(x)))
             x = self.avgpool(x)
             return x
 
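A quick sanity check, not part of the diff above: nn.ReLU carries no parameters or buffers, so replacing the shared self.relu with per-site relu1/relu2/relu3 modules leaves the computed outputs unchanged. A minimal sketch of that check (the tensor shape and offset here are arbitrary):

import torch
import torch.nn as nn

x = torch.randn(2, 8, 4, 4)

shared = nn.ReLU()                    # old style: one module reused at every call site
relu1, relu2 = nn.ReLU(), nn.ReLU()   # new style: a dedicated module per call site

out_shared = shared(shared(x) - 0.5)
out_split = relu2(relu1(x) - 0.5)

assert torch.equal(out_shared, out_split)  # identical results, since ReLU is stateless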