Commit 943e5fc

Remove equinox linear workaround
1 parent d4581cf commit 943e5fc

File tree

1 file changed: +1 -96 lines changed

python/nutpie/normalizing_flow.py

Lines changed: 1 addition & 96 deletions
@@ -10,6 +10,7 @@
 import flowjax.flows
 import numpy as np
 from paramax import Parameterize
+from equinox.nn import Linear
 
 
 def _generate_sequences(k, r_vals):
@@ -115,102 +116,6 @@ def _generate_permutations(rng, n_dim, n_layers, max_run=3):
     return permutations.T, is_in_first.sum(0)
 
 
-# Fix upstream bug for zero-dimensional arrays
-class Linear(eqx.Module, strict=True):
-    """Performs a linear transformation."""
-
-    weight: jax.Array
-    bias: jax.Array | None
-    in_features: Union[int, Literal["scalar"]] = eqx.field(static=True)
-    out_features: Union[int, Literal["scalar"]] = eqx.field(static=True)
-    use_bias: bool = eqx.field(static=True)
-
-    def __init__(
-        self,
-        in_features: Union[int, Literal["scalar"]],
-        out_features: Union[int, Literal["scalar"]],
-        use_bias: bool = True,
-        dtype=None,
-        *,
-        key,
-    ):
-        """**Arguments:**
-
-        - `in_features`: The input size. The input to the layer should be a vector of
-            shape `(in_features,)`.
-        - `out_features`: The output size. The output from the layer will be a vector
-            of shape `(out_features,)`.
-        - `use_bias`: Whether to add on a bias as well.
-        - `dtype`: The dtype to use for the weight and the bias in this layer.
-            Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending
-            on whether JAX is in 64-bit mode.
-        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
-            initialisation. (Keyword only argument.)
-
-        Note that `in_features` also supports the string `"scalar"` as a special value.
-        In this case the input to the layer should be of shape `()`.
-
-        Likewise `out_features` can also be a string `"scalar"`, in which case the
-        output from the layer will have shape `()`.
-        """
-        dtype = np.float32 if dtype is None else dtype
-        wkey, bkey = jax.random.split(key, 2)
-        in_features_ = 1 if in_features == "scalar" else in_features
-        out_features_ = 1 if out_features == "scalar" else out_features
-        if in_features_ == 0:
-            lim = 1.0
-        else:
-            lim = 1 / math.sqrt(in_features_)
-        wshape = (out_features_, in_features_)
-        self.weight = eqx.nn._misc.default_init(wkey, wshape, dtype, lim)
-        bshape = (out_features_,)
-        self.bias = (
-            eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
-        )
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = use_bias
-
-    @jax.named_scope("eqx.nn.Linear")
-    def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
-        """**Arguments:**
-
-        - `x`: The input. Should be a JAX array of shape `(in_features,)`. (Or shape
-            `()` if `in_features="scalar"`.)
-        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
-            (Keyword only argument.)
-
-        !!! info
-
-            If you want to use higher order tensors as inputs (for example featuring
-            batch dimensions) then use `jax.vmap`. For example, for an input `x` of
-            shape `(batch, in_features)`, using
-            ```python
-            linear = equinox.nn.Linear(...)
-            jax.vmap(linear)(x)
-            ```
-            will produce the appropriate output of shape `(batch, out_features)`.
-
-        **Returns:**
-
-        A JAX array of shape `(out_features,)`. (Or shape `()` if
-        `out_features="scalar"`.)
-        """
-
-        if self.in_features == "scalar":
-            if jnp.shape(x) != ():
-                raise ValueError("x must have scalar shape")
-            x = jnp.broadcast_to(x, (1,))
-        x = self.weight @ x
-        if self.bias is not None:
-            x = x + self.bias
-        if self.out_features == "scalar":
-            assert jnp.shape(x) == (1,)
-            x = jnp.squeeze(x)
-        return x
-
-
 class FactoredMLP(eqx.Module, strict=True):
     """Standard Multi-Layer Perceptron; also known as a feed-forward network.
 
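For context on what the workaround did (not part of the diff): the deleted class tracked upstream `equinox.nn.Linear` closely, and its one functional deviation was guarding the fan-in initialisation limit, since `1 / math.sqrt(in_features)` divides by zero when `in_features == 0`. A minimal sketch of that edge case, assuming the installed Equinox release now handles it upstream (the concrete shapes are illustrative, not taken from the commit):

```python
import jax
import jax.numpy as jnp
from equinox.nn import Linear

key = jax.random.PRNGKey(0)

# Zero input features gives a (3, 0) weight matrix, so the usual fan-in
# limit 1 / sqrt(in_features) would divide by zero; the removed workaround
# substituted lim = 1.0 in exactly this case.
layer = Linear(in_features=0, out_features=3, key=key)

out = layer(jnp.zeros((0,)))  # the matmul contributes nothing; output is the bias
assert out.shape == (3,)
```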

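The `"scalar"` convention described in the deleted docstring is likewise unchanged by this commit, since upstream `equinox.nn.Linear` implements the same broadcast-then-squeeze handling that the workaround copied. A short usage sketch (values illustrative):

```python
import jax
import jax.numpy as jnp
from equinox.nn import Linear

key = jax.random.PRNGKey(1)

# in_features="scalar" accepts a ()-shaped input, which is broadcast to (1,);
# out_features="scalar" squeezes the (1,)-shaped result back to shape ().
layer = Linear(in_features="scalar", out_features="scalar", key=key)

y = layer(jnp.array(2.0))
assert y.shape == ()
```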