Commit bfd880a

Merge pull request #53 from lucasb-eyer/param-refactor
Introduce `Param` class.
2 parents eddce17 + df10d88

20 files changed, +179 -188 lines

DeepFried2/Container.py

Lines changed: 10 additions & 15 deletions
@@ -1,4 +1,6 @@
 import DeepFried2 as df
+from collections import OrderedDict as _OrderedDict
+from itertools import chain as _chain


 class Container(df.Module):
@@ -19,24 +21,17 @@ def training(self):
         for module in self.modules:
             module.training()

-    def parameters(self):
-        params, grads = [], []
+    def parameters(self, *a, **kw):
+        params = _chain.from_iterable(m.parameters(*a, **kw) for m in self.modules)

-        for module in self.modules:
-            mod_params, mod_grads = module.parameters()
-            params += mod_params
-            grads += mod_grads
-
-        return params, grads
-
-    def may_decay(self):
-        return sum((m.may_decay() for m in self.modules), [])
+        # We actually need to remove duplicates from the list of parameters
+        # (and their corresponding gradients) in order to support reusing
+        # the same layer at multiple places in the graph,
+        # e.g. do weight sharing.
+        return list(_OrderedDict.fromkeys(params).keys())

     def get_stat_updates(self):
-        stat_updates = []
-        for module in self.modules:
-            stat_updates += module.get_stat_updates()
-        return stat_updates
+        return _chain.from_iterable(m.get_stat_updates() for m in self.modules)

     def add(self, *modules):
         for m in modules:
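
As the new comment in `Container.parameters` notes, duplicates must be dropped so that a layer reused in several places (weight sharing) contributes its parameters only once. A minimal sketch in plain Python, with hypothetical placeholder objects rather than DeepFried2 modules, of why `OrderedDict.fromkeys` is used instead of `set`: it removes duplicates while preserving the original order.

    from collections import OrderedDict
    from itertools import chain

    # Two "modules" report their parameters; `shared` appears in both,
    # as it would when the same layer is used twice in the graph.
    shared, other = object(), object()
    per_module_params = [[shared, other], [shared]]

    params = chain.from_iterable(per_module_params)
    unique = list(OrderedDict.fromkeys(params).keys())
    assert unique == [shared, other]  # de-duplicated, order preserved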

DeepFried2/Module.py

Lines changed: 22 additions & 43 deletions
@@ -1,7 +1,6 @@
 import DeepFried2 as df
 from DeepFried2.utils import make_tensor_or_tensors, aslist

-from collections import OrderedDict as _OrderedDict
 import numpy as _np

 class Module:
@@ -19,44 +18,24 @@ def __init__(self):
     #def __hash__(self):
     #    raise NotImplementedError("You *need* to reimplement hash, even if it's just python's default. See the documentation for more info.")

+    def _addparam(self, *a, **kw):
+        # Add it here because many don't even have params. This avoids misuse.
+        if not hasattr(self, '_params'):
+            self._params = []
+
+        param = df.Param(*a, **kw)
+        self._params.append(param)
+        return param
+
     def zero_grad_parameters(self):
-        _, grads = self.unique_parameters()  # Here, it's just a matter of performance. But even then, not really.
-        for grad in grads:
-            grad.set_value(_np.zeros_like(grad.get_value()))
-
-    def parameters(self):
-        params, grads = [], []
-
-        if hasattr(self, 'weight'):
-            assert hasattr(self, 'grad_weight'), "The layer {} has a `weight` variable but no `grad_weight`, you probably forget to implement it.".format(df.utils.typename(self))
-            params += [self.weight]
-            grads += [self.grad_weight]
-
-        if hasattr(self, 'bias'):
-            assert hasattr(self, 'grad_bias'), "The layer {} has a `bias` variable but no `grad_bias`, you probably forget to implement it.".format(df.utils.typename(self))
-            params += [self.bias]
-            grads += [self.grad_bias]
-
-        return params, grads
-
-    def unique_parameters(self):
-        # We actually need to remove duplicates from the list of parameters
-        # (and their corresponding gradients) in order to support reusing
-        # the same layer at multiple places in the graph,
-        # e.g. do weight sharing.
-        params, grads = self.parameters()
-        return (
-            list(_OrderedDict.fromkeys(params).keys()),
-            list(_OrderedDict.fromkeys(grads).keys()),
-        )
-
-    def may_decay(self):
-        flags = []
-        if hasattr(self, 'weight'):
-            flags += [True]
-        if hasattr(self, 'bias'):
-            flags += [False]
-        return flags
+        for p in self.parameters(trainable_only=True):
+            p.zero_grad()
+
+    def parameters(self, trainable_only=False):
+        params = getattr(self, '_params', [])
+        if trainable_only:
+            params = [p for p in params if p.trainable()]
+        return params

     def evaluate(self):
         self.training_mode = False
@@ -85,10 +64,10 @@ def accumulate_gradients(self, data_in, data_tgt, loss):
         symb_out = self.symb_forward(symb_in)
         symb_err = loss.full_symb_forward(symb_out, symb_tgt)

-        params, grads = self.unique_parameters()
-        symb_grads = df.th.grad(cost=symb_err, wrt=params)
+        params = self.parameters(trainable_only=True)
+        symb_grads = df.th.grad(cost=symb_err, wrt=[p.param for p in params])
+        grads_updates = [(p.grad, p.grad + symb_grad) for p, symb_grad in zip(params, symb_grads)]

-        grads_updates = [(grad, grad + symb_grad) for grad, symb_grad in zip(grads, symb_grads)]
         self._fn_accum_grads[self.training_mode] = df.th.function(
             inputs=aslist(symb_in) + aslist(symb_tgt),
             outputs=symb_err,
@@ -151,8 +130,8 @@ def clear(self):
         self._fn_accum_stats.clear()

     def __getstate__(self):
-        return [p.get_value() for p in self.unique_parameters()[0]]
+        return [p.get_value() for p in self.parameters()]

     def __setstate__(self, state):
-        for p, s in zip(self.unique_parameters()[0], state):
+        for p, s in zip(self.parameters(), state):
             p.set_value(s)
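
Under the new scheme a layer no longer sets `self.weight`/`self.grad_weight` pairs; it registers each parameter through `_addparam`, and the base class collects them. A rough sketch of a toy layer written against this API (the `TinyAffine` name, the shapes, and the `fan` value are made up for illustration and do not appear in the commit):

    import DeepFried2 as df

    class TinyAffine(df.Module):  # hypothetical example layer
        def __init__(self, nin, nout):
            df.Module.__init__(self)
            # `_addparam` wraps the shape/init in a `df.Param` and records it.
            self.W = self._addparam((nin, nout), df.init.xavier(), fan=(nin, nout), name='W_tiny')
            self.b = self._addparam(nout, df.init.const(0), decay=False, name='b_tiny')

        def symb_forward(self, symb_input):
            return df.T.dot(symb_input, self.W.param) + self.b.param

    layer = TinyAffine(3, 2)
    assert len(layer.parameters()) == 2
    assert len(layer.parameters(trainable_only=True)) == 2  # both allocate gradients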

DeepFried2/Optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -10,9 +10,9 @@ def __init__(self, **hyperparams):
     def update_parameters(self, model):

         if model not in self.states:
-            params, grads = model.unique_parameters()
             # TODO: Not only scalar, e.g. Adam might profit from integer t
             hyperparams = {name: df.T.scalar(name) for name in self.hyperparams}
+            params, grads = zip(*[(p.param, p.grad) for p in model.parameters(trainable_only=True)])
             updates = self.get_updates(params, grads, **hyperparams)
             self.states[model] = df.th.function(
                 inputs=list(hyperparams.values()),
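
The `zip(*...)` idiom above transposes a list of `(param, grad)` pairs into two parallel sequences, so `get_updates` keeps its old `(params, grads)` signature. A plain-Python illustration with placeholder strings standing in for the shared variables:

    pairs = [('W', 'grad_W'), ('b', 'grad_b')]  # stand-ins for (p.param, p.grad)
    params, grads = zip(*pairs)
    assert params == ('W', 'b')
    assert grads == ('grad_W', 'grad_b')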

DeepFried2/Param.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import DeepFried2 as df
+import numpy as _np
+
+
+class Param:
+
+    def __init__(self, shape, init, fan=None, name=None, learn=True, decay=True, dtype=df.floatX, **kw):
+        self.init = init
+        self.shape = shape
+        self.fan = fan
+        self.decay = decay
+
+        val = init(self.shape, self.fan).astype(dtype)
+        self.param = df.th.shared(val, name=name, **kw)
+
+        if learn:
+            grad_name = 'grad_' + name if name is not None else None
+            self.grad = df.th.shared(_np.zeros_like(val), name=grad_name, **kw)
+        else:
+            self.grad = None
+
+    def get_value(self):
+        return self.param.get_value()
+
+    def set_value(self, val):
+        self.param.set_value(val)
+
+    def reinit(self):
+        self.param.set_value(self.init(self.shape, self.fan).astype(self.param.dtype))
+
+    def zero_grad(self):
+        self.grad.set_value(_np.zeros(self.shape, self.param.dtype))
+
+    def may_decay(self):
+        return self.grad is not None and self.decay
+
+    def trainable(self):
+        return self.grad is not None
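
Outside of a layer, a `Param` can also be built directly. A small sketch of the lifecycle, assuming Theano is available and that `df.init.const` returns a callable taking `(shape, fan)` as `Param.__init__` expects:

    import numpy as np
    import DeepFried2 as df

    p = df.Param(3, df.init.const(1), name='toy')   # learn=True by default
    print(p.get_value())     # [1. 1. 1.]
    print(p.trainable())     # True: a gradient buffer was allocated
    print(p.may_decay())     # True: trainable and decay=True

    p.set_value(np.arange(3, dtype=df.floatX))
    p.zero_grad()            # reset the accumulated gradient to zeros
    p.reinit()               # back to the initializer's values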

DeepFried2/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@

 import DeepFried2.init as init

+from .Param import Param
+
 from .Module import Module
 from .layers import *

DeepFried2/criteria/WeightDecay.py

Lines changed: 8 additions & 16 deletions
@@ -5,28 +5,20 @@


 class L1WeightDecay:
-    def __init__(self, *containers):
-        self.containers = containers
+    def __init__(self, *modules):
+        self.modules = modules

     def symb_forward(self):
-        return sum(df.T.sum(abs(p)) for p in collect_decayable_params(*self.containers))
+        return sum(df.T.sum(abs(p)) for p in _collect_decayable_params(self.modules))


 class L2WeightDecay:
-    def __init__(self, *containers):
-        self.containers = containers
+    def __init__(self, *modules):
+        self.modules = modules

     def symb_forward(self):
-        return sum(df.T.sum(p**2) for p in collect_decayable_params(*self.containers))
+        return sum(df.T.sum(p**2) for p in _collect_decayable_params(self.modules))


-def collect_decayable_params(*containers):
-    decay_params = []
-    for c in containers:
-        params, _ = c.unique_parameters()  # TODO: unique or non-unique?
-        may = c.may_decay()
-
-        assert len(params) == len(may), "Possible implementation bug in `{}.may_decay()`: {} parameters, but {} decay infos.".format(df.utils.typename(c), len(params), len(may))
-
-        decay_params += [p for p,m in zip(params, may) if may]
-    return decay_params
+def _collect_decayable_params(modules):
+    return [p.param for c in modules for p in c.parameters() if p.may_decay()]
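
`_collect_decayable_params` now simply asks each parameter whether it may decay, so the decision made at registration time (`decay=False` for biases, `learn=False` for inference buffers) is what excludes it from the penalty. A hedged sketch of that filter, using `Param` directly rather than a full module:

    import DeepFried2 as df

    W = df.Param(4, df.init.const(1), name='W')                    # decays
    b = df.Param(4, df.init.const(0), name='b', decay=False)       # skipped
    buf = df.Param(4, df.init.const(0), name='buf', learn=False)   # skipped too

    decayable = [p.param for p in (W, b, buf) if p.may_decay()]
    assert len(decayable) == 1
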
Lines changed: 11 additions & 10 deletions
@@ -1,13 +1,13 @@
 import DeepFried2 as df
-from DeepFried2.utils import create_param_and_grad, expand
+from DeepFried2.utils import expand
 from theano.sandbox.cuda.basic_ops import gpu_contiguous, gpu_alloc_empty
 from theano.sandbox.cuda import dnn

 import numpy as np


 class BackwardsConvolutionCUDNN(df.Module):
-    def __init__(self, nchan_in, nchan_out, filter_size, stride=1, border=0, mode='cross', with_bias=True, initW=df.init.xavier(), initB=df.init.const(0)):
+    def __init__(self, nchan_in, nchan_out, filter_size, stride=1, border=0, mode='cross', init=df.init.xavier(), bias=df.init.const(0)):
         # mode='cross' is the default in Lasagne[1], Torch[2], matConvNet[3], Caffee[4].
         #
         # 1: https://github.com/Lasagne/Lasagne/blob/63d44a0d/lasagne/layers/dnn.py#L299
@@ -19,7 +19,6 @@ def __init__(self, nchan_in, nchan_out, filter_size, stride=1, border=0, mode='c
         self.nchan_out = nchan_out
         self.filter_size = filter_size
         self.mode = mode
-        self.with_bias = with_bias
         self.stride = expand(stride, len(filter_size), 'stride')
         self.border = expand(border, len(filter_size), 'border')

@@ -30,27 +29,29 @@ def __init__(self, nchan_in, nchan_out, filter_size, stride=1, border=0, mode='c

         w_shape = (nchan_in, nchan_out) + self.filter_size
         w_fan = (np.prod(self.filter_size)*nchan_out, np.prod(self.filter_size)*nchan_in)
+        w_name = ('Wconv_{},{}@{}' + 'x{}'*(len(w_shape) - 3)).format(*w_shape)
+        self.W = self._addparam(w_shape, init, fan=w_fan, name=w_name)

-        param_name = 'Wconv_{},{}@{}' + 'x{}'*(len(w_shape) - 3)
-        self.weight, self.grad_weight = create_param_and_grad(w_shape, initW, fan=w_fan, name=param_name.format(*w_shape))
-        if self.with_bias:
-            self.bias, self.grad_bias = create_param_and_grad(nchan_out, initB, name='bconv_{}'.format(nchan_out))
+        if bias not in (None, False):
+            self.b = self._addparam(nchan_out, bias, decay=False, name='bconv_{}'.format(nchan_out))
+        else:
+            self.b = None


     def symb_forward(self, symb_input):
         """ creates dummy forward conv and uses its gradient as backwards pass """
         """ This code is mostly taken from https://github.com/Newmu/dcgan_code/blob/master/lib/ops.py """
         img = gpu_contiguous(symb_input)
-        kerns = gpu_contiguous(self.weight)
+        kerns = gpu_contiguous(self.W.param)

         alloc_shape = (img.shape[0], kerns.shape[1]) + tuple(i*d for i,d in zip(img.shape[2:],self.stride))
         desc = dnn.GpuDnnConvDesc(border_mode=self.border, subsample=self.stride, conv_mode=self.mode)(gpu_alloc_empty(*alloc_shape).shape, kerns.shape)
         out = gpu_alloc_empty(*alloc_shape)
         grad = dnn.GpuDnnConv3dGradI if symb_input.ndim == 5 else dnn.GpuDnnConvGradI
         conv_output = grad()(kerns, img, out, desc)

-        if self.with_bias:
+        if self.b is not None:
             d_shuffle = ('x', 0) + tuple('x') * (symb_input.ndim-2)
-            conv_output += self.bias.dimshuffle(*d_shuffle)
+            conv_output += self.b.param.dimshuffle(*d_shuffle)

         return conv_output
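
Two small points about the new constructor: passing `bias=None` (or `False`) now replaces the old `with_bias=False` flag, and the weight's name encodes the filter shape via the same format string as before. A quick check of what that string produces for a hypothetical 3x3 convolution from 16 to 32 channels:

    w_shape = (16, 32, 3, 3)  # (nchan_in, nchan_out) + filter_size
    w_name = ('Wconv_{},{}@{}' + 'x{}'*(len(w_shape) - 3)).format(*w_shape)
    assert w_name == 'Wconv_16,32@3x3'
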
Lines changed: 27 additions & 27 deletions
@@ -1,5 +1,5 @@
 import DeepFried2 as df
-from DeepFried2.utils import create_param, create_param_and_grad, aslist
+from DeepFried2.utils import aslist

 import numpy as _np

@@ -19,16 +19,16 @@ def __init__(self, n_features, eps=1e-5):

         self.shape = tuple(aslist(n_features))

-        self.weight, self.grad_weight = create_param_and_grad(n_features, df.init.const(1), name='W_BN_{}'.format(n_features))
-        self.bias, self.grad_bias = create_param_and_grad(n_features, df.init.const(0), name='b_BN_{}'.format(n_features))
+        self.W = self._addparam(n_features, df.init.const(1), name='W_BN_{}'.format(n_features))
+        self.b = self._addparam(n_features, df.init.const(0), name='b_BN_{}'.format(n_features), decay=False)

-        self.inference_weight = create_param(n_features, df.init.const(1), name='W_BN_{}_inf'.format(n_features))
-        self.inference_bias = create_param(n_features, df.init.const(0), name='b_BN_{}_inf'.format(n_features))
+        self.Winf = self._addparam(n_features, df.init.const(1), name='W_BN_{}_inf'.format(n_features), learn=False)
+        self.binf = self._addparam(n_features, df.init.const(0), name='b_BN_{}_inf'.format(n_features), learn=False)

         # These are buffers for collecting the minibatch statistics.
-        self.buffer_variance = create_param(n_features, df.init.const(1), name='BN_var_{}'.format(n_features))
-        self.buffer_mean = create_param(n_features, df.init.const(0), name='BN_mean_{}'.format(n_features))
-        self.buffer_counts = df.th.shared(_np.asarray(0, dtype=df.floatX), name='BN_count_{}'.format(n_features))
+        self.buf_var = df.th.shared(_np.full(n_features, 1, df.floatX), name='BN_var_{}'.format(n_features))
+        self.buf_mean = df.th.shared(_np.full(n_features, 0, df.floatX), name='BN_mean_{}'.format(n_features))
+        self.buf_count = df.th.shared(_np.asarray(0, dtype=df.floatX), name='BN_count_{}'.format(n_features))

         self.eps = eps or 1e-5

@@ -46,6 +46,9 @@ def symb_forward(self, symb_input):

         # And for the dimshuffle, similar story. Put 'x' on the axes we're normalizing.
         d_shuffle = ['x'] + list(range(len(self.shape))) + ['x']*(symb_input.ndim-len(self.shape)-1)
+        # Shorthand:
+        def dshuf(x):
+            return x.dimshuffle(*d_shuffle)

         # For example, for the usual case of images where dimensions are
         # (B,C,H,W), axis == [0, 2, 3] and d_shuffle == ['x', 0, 'x', 'x']
@@ -54,42 +57,39 @@
             self.batch_mean = df.T.mean(symb_input, axis=axis)
             self.batch_var = df.T.var(symb_input, axis=axis)

-            return (symb_input - self.batch_mean.dimshuffle(*d_shuffle)) / df.T.sqrt(self.batch_var + self.eps).dimshuffle(*d_shuffle) * self.weight.dimshuffle(*d_shuffle) + self.bias.dimshuffle(*d_shuffle)
+            symb_input = (symb_input - dshuf(self.batch_mean)) / dshuf(df.T.sqrt(self.batch_var + self.eps))
+
+            return symb_input * dshuf(self.W.param) + dshuf(self.b.param)
         else:
-            return symb_input * self.inference_weight.dimshuffle(*d_shuffle) + self.inference_bias.dimshuffle(*d_shuffle)
+            return symb_input * dshuf(self.Winf.param) + dshuf(self.binf.param)

     def get_stat_updates(self):
         assert (self.batch_mean is not None) and (self.batch_var is not None), "You need to do a forward pass first"

-        stat_updates = list()
-        stat_updates.append((self.buffer_mean,
-                             (self.buffer_mean * self.buffer_counts + self.batch_mean) / (self.buffer_counts + 1.0)))
-
-        stat_updates.append((self.buffer_variance,
-                             (self.buffer_variance * self.buffer_counts + self.batch_var) / (self.buffer_counts + 1.0)))
-
-        stat_updates.append((self.buffer_counts,
-                             self.buffer_counts + 1.0))
-
-        return stat_updates
+        # Update buffer statistics with current batch's statistics.
+        return [
+            (self.buf_mean, (self.buf_mean * self.buf_count + self.batch_mean) / (self.buf_count + 1.0)),
+            (self.buf_var, (self.buf_var * self.buf_count + self.batch_var) / (self.buf_count + 1.0)),
+            (self.buf_count, self.buf_count + 1.0),
+        ]

     def training(self):
        df.Module.training(self)
-        self.buffer_counts.set_value(0)
+        self.buf_count.set_value(0)
         self.batch_mean = None
         self.batch_var = None

     def evaluate(self):
         df.Module.evaluate(self)
-        self.inference_weight.set_value(self.weight.get_value() / _np.sqrt(self.buffer_variance.get_value() + self.eps))
-        self.inference_bias.set_value(self.bias.get_value() - self.inference_weight.get_value() * self.buffer_mean.get_value())
+        self.Winf.set_value(self.W.get_value() / _np.sqrt(self.buf_var.get_value() + self.eps))
+        self.binf.set_value(self.b.get_value() - self.Winf.get_value() * self.buf_mean.get_value())

     def __getstate__(self):
         regular = df.Module.__getstate__(self)
-        return [b.get_value() for b in (self.buffer_mean, self.buffer_variance, self.buffer_counts)] + regular
+        return [buf.get_value() for buf in (self.buf_mean, self.buf_var, self.buf_count)] + regular

     def __setstate__(self, state):
         istate = iter(state)
-        for b, s in zip((self.buffer_mean, self.buffer_variance, self.buffer_counts), istate):
-            b.set_value(s)
+        for buf, val in zip((self.buf_mean, self.buf_var, self.buf_count), istate):
+            buf.set_value(val)
         df.Module.__setstate__(self, istate)