We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 796178d commit 84237ffCopy full SHA for 84237ff
ngclearn/utils/model_utils.py
@@ -534,6 +534,19 @@ def threshold_cauchy(x, lmbda):
534
535
@jit
536
def layer_normalize(x, shift=0., scale=1.):
537
+ """
538
+ Applies layer normalization to input data `x`
539
+
540
+ Args:
541
+ x: data to apply threshold function over
542
543
+ shift: the compensating mean/shift factor/parameters (to undo mean subtraction)
544
545
+ scale: the compensating re-scaling factor/parameters (to undo standard deviation division)
546
547
+ Returns:
548
+ layer-normalized data samples `x`
549
550
xmu = jnp.mean(x, axis=1, keepdims=True)
551
xsigma = jnp.sqrt(jnp.mean(jnp.square(x - xmu)))
552
_x = (x - xmu)/(xsigma + 1e-6)
0 commit comments