Skip to content

Commit f4ff69a

Browse files
committed
fix test bugs
1 parent 19cc77c commit f4ff69a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

brainpy/_src/measure/correlation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def _cc(i, j):
107107
return np.mean(np.asarray(res))
108108

109109

110+
def _f_signal(signal):
111+
return jnp.mean(signal * signal) - jnp.mean(signal) ** 2
112+
113+
110114
def voltage_fluctuation(potentials, numpy=True, method='loop'):
111115
r"""Calculate neuronal synchronization via voltage variance.
112116
@@ -177,15 +181,14 @@ def voltage_fluctuation(potentials, numpy=True, method='loop'):
177181
avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2
178182

179183
if method == 'loop':
180-
_var = lambda aa: bm.for_loop(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2,
181-
operands=jnp.moveaxis(aa, 0, 1))
184+
_var = bm.for_loop(_f_signal, operands=jnp.moveaxis(potentials, 0, 1))
182185

183186
elif method == 'vmap':
184-
_var = vmap(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2, in_axes=1)
187+
_var = vmap(_f_signal, in_axes=1)(potentials)
185188
else:
186189
raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')
187190

188-
var_mean = jnp.mean(_var(potentials))
191+
var_mean = jnp.mean(_var)
189192
r = jnp.where(var_mean == 0., 1., avg_var / var_mean)
190193
return bm.as_numpy(r) if numpy else r
191194

0 commit comments

Comments
 (0)