Skip to content

Commit 92633f9

Browse files
committed
update model utils
1 parent 27ae7e2 commit 92633f9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ngclearn/utils/model_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def create_function(fun_name, args=None):
8585
if fun_name == "tanh":
8686
fx = tanh
8787
dfx = d_tanh
88-
elif "bkwta" in fun_name:
88+
elif fun_name == "bkwta":
8989
fx = bkwta
9090
dfx = bkwta #d_identity
9191
elif fun_name == "sine":
@@ -651,8 +651,8 @@ def layer_normalize(x, shift=0., scale=1.):
651651
layer-normalized data samples `x`
652652
"""
653653
xmu = jnp.mean(x, axis=1, keepdims=True)
654-
xsigma = jnp.sqrt(jnp.mean(jnp.square(x - xmu)).clip(min=1e-6))
655-
_x = (x - xmu)/(xsigma + 1e-6)
654+
xsigma = jnp.sqrt(jnp.mean(jnp.square(x - xmu)).clip(min=1e-6)).clip(min=1e-6)
655+
_x = (x - xmu) / xsigma
656656
return _x * scale + shift
657657

658658
@jit

0 commit comments

Comments
 (0)