Skip to content

Commit d0d215a

Browse files
authored
Merge pull request #90 from lucasb-eyer/resblock
Add ResNet v2 block creation utils.
2 parents 7ec8ff2 + ee30279 commit d0d215a

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

DeepFried2/zoo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .download import download
2+
from . import resnet
23
from . import vgg16
34
from . import vgg19

DeepFried2/zoo/resnet.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import DeepFried2 as df
2+
3+
4+
class Add(df.Module):
5+
def symb_forward(self, symb_inputs):
6+
assert isinstance(symb_inputs, (list, tuple)), "Input to `Add` must be multiple tensors."
7+
8+
s = symb_inputs[0]
9+
for x in symb_inputs[1:]:
10+
s = s + x
11+
return s
12+
13+
14+
def block(nchan, fs=(3,3), body=None):
15+
return df.Sequential(
16+
df.Parallel(
17+
df.Sequential(
18+
df.BatchNormalization(nchan), df.ReLU(),
19+
df.SpatialConvolutionCUDNN(nchan, nchan, fs, border='same', init=df.init.prelu(), bias=False),
20+
df.BatchNormalization(nchan), df.ReLU(),
21+
df.SpatialConvolutionCUDNN(nchan, nchan, fs, border='same', init=df.init.prelu(), bias=False)
22+
) if body is None else body,
23+
df.Identity()
24+
),
25+
Add()
26+
)
27+
28+
29+
def block_proj(nin, nout, fs=(3,3), body=None):
30+
return df.Sequential(
31+
df.Parallel(
32+
df.Sequential(
33+
df.BatchNormalization(nin), df.ReLU(),
34+
df.SpatialConvolutionCUDNN(nin, nout, fs, border='same', init=df.init.prelu(), bias=False),
35+
df.BatchNormalization(nout), df.ReLU(),
36+
df.SpatialConvolutionCUDNN(nout, nout, fs, border='same', init=df.init.prelu(), bias=False)
37+
) if body is None else body,
38+
df.SpatialConvolutionCUDNN(nin, nout, (1,)*len(fs)),
39+
),
40+
Add()
41+
)

0 commit comments

Comments
 (0)