|
2 | 2 | import jax |
3 | 3 | import jax.numpy as jnp |
4 | 4 | from beartype.typing import Any, Hashable, Sequence |
| 5 | +from equinox import Module, field |
5 | 6 | from equinox.nn import State |
6 | 7 | from jaxtyping import Array, Float, PRNGKeyArray |
7 | 8 |
|
@@ -73,8 +74,8 @@ def __init__( |
73 | 74 | eps: float = 1e-5, |
74 | 75 | momentum: float = 0.1, |
75 | 76 | affine: bool = True, |
76 | | - inference: bool = False, |
77 | 77 | dtype: Any | None = None, |
| 78 | + inference: bool = False, |
78 | 79 | ): |
79 | 80 | if dtype is None: |
80 | 81 | dtype = default_floating_dtype() |
@@ -364,3 +365,100 @@ def __call__(self, x: Array, *_, key: PRNGKeyArray | None = None, **__) -> Array |
364 | 365 |
|
365 | 366 | out = out.astype(orig_dtype) |
366 | 367 | return out |
| 368 | + |
| 369 | + |
class ResidualLayerNorm(Module):
    """Layer normalisation with a residual scale parameter.

    Normalises the input by subtracting the mean and dividing by the standard
    deviation computed over the entire array. The learnable affine scale
    parameter is formulated as a residual ``1 + weight``, where ``weight`` is
    initialised to zero.

    Unlike ``LayerNorm``, this module expects the input to exactly match the
    configured ``shape`` and does not automatically broadcast over leading
    batch dimensions; use ``jax.vmap`` for batched inputs.

    Computation is performed at a higher precision (at least ``float32``) and
    the result is cast back to the original dtype.

    Args:
        shape: The exact shape of the unbatched input array. Pass a single
            ``int`` for the common 1-D case.
        eps: Small constant added to the variance for numerical stability.
            Defaults to ``1e-5``.
        use_weight: If ``True``, learn a per-element residual scale parameter
            initialised to ``0``. Defaults to ``True``.
        use_bias: If ``True``, learn a per-element bias parameter initialised
            to ``0``. Defaults to ``False``.
        dtype: Floating-point dtype for the affine parameters. Defaults to
            ``None``, in which case JAX's default floating-point dtype is used
            (the ``jnp.zeros`` default).

    Raises:
        ValueError: If the input shape does not exactly match ``shape``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> rln = ResidualLayerNorm(shape=64)
        >>> x = jnp.ones((10, 64))
        >>> jax.vmap(rln)(x).shape
        (10, 64)
    """

    shape: tuple[int, ...] = field(static=True)
    eps: float = field(static=True)
    use_weight: bool = field(static=True)
    use_bias: bool = field(static=True)
    weight: Float[Array, "*shape"] | None
    bias: Float[Array, "*shape"] | None

    def __init__(
        self,
        shape: int | Sequence[int],
        eps: float = 1e-5,
        use_weight: bool = True,
        use_bias: bool = False,
        dtype: Any | None = None,
    ):
        # Normalise `shape` to a tuple so the exact-equality check against
        # `x.shape` in `__call__` works for both int and sequence inputs.
        if isinstance(shape, int):
            shape = (shape,)
        else:
            shape = tuple(shape)
        self.shape = shape
        self.eps = eps
        self.use_weight = use_weight
        self.use_bias = use_bias
        # `weight` is a *residual* scale: the effective scale applied in
        # `__call__` is `1 + weight`, so zero-initialisation means the layer
        # starts out as plain (unscaled) normalisation.
        self.weight = jnp.zeros(shape, dtype=dtype) if use_weight else None
        self.bias = jnp.zeros(shape, dtype=dtype) if use_bias else None

    def __call__(
        self,
        x: Float[Array, "*shape"],
        *,
        key: PRNGKeyArray | None = None,
    ) -> Array:
        """Normalise ``x`` and apply the residual affine transform.

        Args:
            x: Input array whose shape must exactly equal ``self.shape``.
            key: Ignored; accepted for call-signature compatibility with
                other equinox modules.

        Returns:
            The normalised array, cast back to ``x``'s original dtype.

        Raises:
            ValueError: If ``x.shape`` does not equal ``self.shape``.
        """
        if x.shape != self.shape:
            raise ValueError(
                f"Expected shape {self.shape}, got {x.shape}. You might need jax.vmap."
            )

        orig_dtype = x.dtype
        # Up-cast to at least float32 for the reduction arithmetic; the
        # result is cast back to the original dtype at the end.
        with jax.numpy_dtype_promotion("standard"):
            dtype = jnp.result_type(x.dtype, jnp.float32)

        x = x.astype(dtype)
        mean = jnp.mean(x, keepdims=True)
        variance = jnp.var(x, keepdims=True)
        # Clamp to zero: floating-point error can make the computed variance
        # very slightly negative, which would poison rsqrt.
        variance = jnp.maximum(0.0, variance)
        inv = jax.lax.rsqrt(variance + self.eps)
        out = (x - mean) * inv

        if self.use_weight:
            assert self.weight is not None
            # Residual formulation: effective scale is `1 + weight`.
            out = (1.0 + self.weight.astype(dtype)) * out
        if self.use_bias:
            assert self.bias is not None
            out = out + self.bias.astype(dtype)

        return out.astype(orig_dtype)
0 commit comments