Skip to content

Commit e3b2687

Browse files
committed
minor edits and doc edit
1 parent 4b4a88b commit e3b2687

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

ngclearn/engine/cables/dcable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(self, inp, out, init_kernels=None, shared_param_path=None,
130130
if self.shared_param_path is not None:
131131
cable_to_mirror, path_type = shared_param_path
132132
"""
133-
path_type = A, A^T, -A^T, A^T+b, -A^T+b
133+
path_type = A, A^T, A+b, -A^T, A^T+b, -A^T+b
134134
"""
135135
self.path_type = path_type
136136
A = cable_to_mirror.params["A"]

ngclearn/engine/nodes/node.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,21 @@ def compile(self):
7575
comp_dim = curr_comp.shape[1]
7676
comp_var_name = curr_comp.name
7777
comp_var_name = comp_var_name[0:curr_comp.name.index(":"):1]
78-
self.compartments[cname] = \
79-
tf.Variable(tf.zeros([self.batch_size,comp_dim]), name=comp_var_name)
78+
if self.do_inplace == True:
79+
self.compartments[cname] = \
80+
tf.Variable(tf.zeros([self.batch_size,comp_dim]), name=comp_var_name)
81+
else:
82+
self.compartments[cname] = tf.zeros([self.batch_size,comp_dim])
8083
for mname in self.mask_names:
8184
curr_mask = self.masks.get(mname)
8285
if curr_mask is not None:
8386
mask_var_name = curr_mask.name
8487
mask_var_name = mask_var_name[0:curr_mask.name.index(":"):1]
8588
mask_dim = curr_mask.shape[1]
86-
self.masks[mname] = tf.Variable(tf.ones([self.batch_size,mask_dim]), name=mask_var_name)
89+
if self.do_inplace == True:
90+
self.masks[mname] = tf.Variable(tf.ones([self.batch_size,mask_dim]), name=mask_var_name)
91+
else:
92+
self.masks[mname] = tf.ones([self.batch_size,mask_dim])
8793
info["object_type"] = self.node_type
8894
info["object_name"] = self.name
8995
info["n_connected_cables"] = len(self.connected_cables)

ngclearn/generator/temporal/noisy_sin.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ def sample(self):
4040
"""
4141
# increment the internal tracking of x
4242
x = self.x_prev + self.dt
43+
x = x.astype(np.float32)
4344
# compute the Gaussian/white-noise corrupted sine wave output
4445
y = np.sin(x) + np.random.normal(size=self.sigma.shape) * self.sigma
46+
y = y.astype(np.float32)
4547
# Store x sample into x_prev state
4648
self.x_prev = x # this step ensures that the next noise sample is dependent upon current one
47-
return x
49+
return y
4850

4951
def reset(self):
5052
"""
@@ -53,4 +55,6 @@ def reset(self):
5355
if self.x_initial is not None:
5456
self.x_prev = self.x_initial
5557
else: # reset the noise process back to zero
56-
self.x_prev = np.zeros_like(self.mean)
58+
self.x_prev = np.zeros_like(self.sigma)
59+
self.x_prev = self.x_prev.astype(np.float32)
60+

ngclearn/generator/temporal/oh_process.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def sample(self):
4848
+ self.theta * (self.mean - self.x_prev) * self.dt
4949
+ self.std_dev * np.sqrt(self.dt) * np.random.normal(size=self.mean.shape)
5050
)
51+
x = x.astype(np.float32)
5152
# Store x sample into x_prev state
5253
self.x_prev = x # this step ensures that the next noise sample is dependent upon current one
5354
return x
@@ -60,3 +61,4 @@ def reset(self):
6061
self.x_prev = self.x_initial
6162
else: # reset the noise process back to zero
6263
self.x_prev = np.zeros_like(self.mean)
64+
self.x_prev = self.x_prev.astype(np.float32)

0 commit comments

Comments
 (0)