Skip to content

Commit d3ff8ec

Browse files
authored
Merge pull request #96 from lucasb-eyer/parallel
Split Parallel into Parallel and RepeatInput.
2 parents 97ecd3b + 4427331 commit d3ff8ec

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

DeepFried2/containers/Parallel.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33

44
class Parallel(df.Container):
55
def symb_forward(self, symb_input):
6-
# TODO: Not sure if this polymorphism is any good!
7-
if isinstance(symb_input, (list, tuple)):
8-
assert len(symb_input) == len(self.modules), "If `{}` has multiple inputs, it should be the same amount as it has modules.".format(df.utils.typename(self))
9-
return tuple(module(symb_in) for module, symb_in in zip(self.modules, symb_input))
10-
else:
11-
return tuple(module(symb_input) for module in self.modules)
6+
assert isinstance(symb_input, (list, tuple)), "`{}` must have >1 inputs".format(df.utils.typename(self))
7+
assert len(symb_input) == len(self.modules), "`{}` should have the same number of inputs ({}) as modules ({}).".format(df.utils.typename(self), len(symb_input), len(self.modules))
8+
return tuple(module(symb_in) for module, symb_in in zip(self.modules, symb_input))
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import DeepFried2 as df
2+
3+
4+
class RepeatInput(df.Container):
5+
def symb_forward(self, symb_input):
6+
return tuple(module(symb_input) for module in self.modules)

DeepFried2/containers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .Sequential import Sequential
22
from .Parallel import Parallel
3+
from .RepeatInput import RepeatInput
34
from .Concat import Concat
45
from .StoreOut import StoreOut
56
from .ActiveIn import ActiveIn, InactiveIn, TrainingOnly, TestingOnly

DeepFried2/zoo/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def symb_forward(self, symb_inputs):
1313

1414
def block(nchan, fs=(3,3), body=None):
1515
return df.Sequential(
16-
df.Parallel(
16+
df.RepeatInput(
1717
df.Sequential(
1818
df.BatchNormalization(nchan), df.ReLU(),
1919
df.SpatialConvolutionCUDNN(nchan, nchan, fs, border='same', init=df.init.prelu(), bias=False),
@@ -28,7 +28,7 @@ def block(nchan, fs=(3,3), body=None):
2828

2929
def block_proj(nin, nout, fs=(3,3), body=None):
3030
return df.Sequential(
31-
df.Parallel(
31+
df.RepeatInput(
3232
df.Sequential(
3333
df.BatchNormalization(nin), df.ReLU(),
3434
df.SpatialConvolutionCUDNN(nin, nout, fs, border='same', init=df.init.prelu(), bias=False),

0 commit comments

Comments
 (0)