Skip to content

Commit 58a8710

Browse files
authored
Merge pull request #85 from lucasb-eyer/init-array-shortcut
Init array shortcut
2 parents 582dba0 + 585b4ed commit 58a8710

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

DeepFried2/Param.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@ class Param(object):
66

77
def __init__(self, shape, init, fan=None, name=None, learn=True, decay=True, dtype=df.floatX, **kw):
88
self.init = init
9-
self.shape = shape
9+
self.shape = (shape,) if _np.isscalar(shape) else tuple(shape)
1010
self.fan = fan
1111
self.decay = decay
1212

13-
val = init(self.shape, self.fan).astype(dtype)
13+
# Support a useful shortcut for initializing with an array-like:
14+
# TODO: It would be nicer to use Python's buffer-interface.
15+
if hasattr(init, 'shape') and hasattr(init, 'dtype'):
16+
self.init = df.init.array(init)
17+
18+
val = self.init(self.shape, self.fan).astype(dtype)
1419
self.param = df.th.shared(val, name=name, **kw)
1520

1621
if learn:

0 commit comments

Comments
 (0)