Skip to content

Commit dc867de

Browse files
committed
Renames trainable to learnable.
1 parent e309711 commit dc867de

File tree

4 files changed: +9 additions, -9 deletions

DeepFried2/Module.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ def _addparam_optional(self, shape, init, *a, **kw):
4141

4242

4343
def zero_grad_parameters(self):
44-
for p in self.parameters(trainable_only=True):
44+
for p in self.parameters(learnable_only=True):
4545
p.zero_grad()
4646

47-
def parameters(self, trainable_only=False):
47+
def parameters(self, learnable_only=False):
4848
params = getattr(self, '_params', [])
49-
if trainable_only:
50-
params = [p for p in params if p.trainable()]
49+
if learnable_only:
50+
params = [p for p in params if p.learnable()]
5151
return params
5252

5353
def evaluate(self):
@@ -88,7 +88,7 @@ def accumulate_gradients(self, data_in, data_tgt, crit):
8888
symb_cost = crit(symb_out, symb_tgt)
8989
extra_out = self.get_extra_outputs() + crit.get_extra_outputs()
9090

91-
params = self.parameters(trainable_only=True)
91+
params = self.parameters(learnable_only=True)
9292
symb_grads = df.th.grad(cost=symb_cost, wrt=[p.param for p in params])
9393
grads_updates = [(p.grad, p.grad + symb_grad) for p, symb_grad in zip(params, symb_grads)]
9494

DeepFried2/Optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def update_parameters(self, model):
1212
if model not in self.states:
1313
# TODO: Not only scalar, e.g. Adam might profit from integer t
1414
hyperparams = {name: df.T.scalar(name) for name in self.hyperparams}
15-
params, grads = zip(*[(p.param, p.grad) for p in model.parameters(trainable_only=True)])
15+
params, grads = zip(*[(p.param, p.grad) for p in model.parameters(learnable_only=True)])
1616
updates = self.get_updates(params, grads, **hyperparams)
1717
self.states[model] = df.th.function(
1818
inputs=list(hyperparams.values()),

DeepFried2/Param.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,5 @@ def zero_grad(self):
4141
def may_decay(self):
4242
return self.grad is not None and self.decay
4343

44-
def trainable(self):
44+
def learnable(self):
4545
return self.grad is not None

DeepFried2/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def tensors_for_ndarrays(datas, basename):
3737
raise TypeError("I only understand lists or tuples of numpy arrays! (possibly nested)")
3838

3939

40-
def count_params(module, trainable_only=True):
41-
return sum(p.get_value().size for p in module.parameters(trainable_only=trainable_only))
40+
def count_params(module, learnable_only=True):
41+
return sum(p.get_value().size for p in module.parameters(learnable_only=learnable_only))
4242

4343

4444
def flatten(what, types=(list, tuple), none_to_empty=False):

0 commit comments

Comments (0)