@@ -10,6 +10,7 @@
import flowjax.flows
import numpy as np
from paramax import Parameterize
+from equinox.nn import Linear


def _generate_sequences(k, r_vals):
@@ -115,102 +116,6 @@ def _generate_permutations(rng, n_dim, n_layers, max_run=3):
    return permutations.T, is_in_first.sum(0)


-# Fix upstream bug for zero-dimensional arrays
-class Linear(eqx.Module, strict=True):
-    """Performs a linear transformation."""
-
-    weight: jax.Array
-    bias: jax.Array | None
-    in_features: Union[int, Literal["scalar"]] = eqx.field(static=True)
-    out_features: Union[int, Literal["scalar"]] = eqx.field(static=True)
-    use_bias: bool = eqx.field(static=True)
-
-    def __init__(
-        self,
-        in_features: Union[int, Literal["scalar"]],
-        out_features: Union[int, Literal["scalar"]],
-        use_bias: bool = True,
-        dtype=None,
-        *,
-        key,
-    ):
-        """**Arguments:**
-
-        - `in_features`: The input size. The input to the layer should be a vector of
-          shape `(in_features,)`.
-        - `out_features`: The output size. The output from the layer will be a vector
-          of shape `(out_features,)`.
-        - `use_bias`: Whether to add on a bias as well.
-        - `dtype`: The dtype to use for the weight and the bias in this layer.
-          Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending
-          on whether JAX is in 64-bit mode.
-        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
-          initialisation. (Keyword only argument.)
-
-        Note that `in_features` also supports the string `"scalar"` as a special value.
-        In this case the input to the layer should be of shape `()`.
-
-        Likewise `out_features` can also be a string `"scalar"`, in which case the
-        output from the layer will have shape `()`.
-        """
-        dtype = np.float32 if dtype is None else dtype
-        wkey, bkey = jax.random.split(key, 2)
-        in_features_ = 1 if in_features == "scalar" else in_features
-        out_features_ = 1 if out_features == "scalar" else out_features
-        if in_features_ == 0:
-            lim = 1.0
-        else:
-            lim = 1 / math.sqrt(in_features_)
-        wshape = (out_features_, in_features_)
-        self.weight = eqx.nn._misc.default_init(wkey, wshape, dtype, lim)
-        bshape = (out_features_,)
-        self.bias = (
-            eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
-        )
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = use_bias
-
-    @jax.named_scope("eqx.nn.Linear")
-    def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
-        """**Arguments:**
-
-        - `x`: The input. Should be a JAX array of shape `(in_features,)`. (Or shape
-          `()` if `in_features="scalar"`.)
-        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
-          (Keyword only argument.)
-
-        !!! info
-
-            If you want to use higher order tensors as inputs (for example
-            featuring batch dimensions) then use `jax.vmap`. For example, for
-            an input `x` of shape `(batch, in_features)`, using
-            ```python
-            linear = equinox.nn.Linear(...)
-            jax.vmap(linear)(x)
-            ```
-            will produce the appropriate output of shape `(batch, out_features)`.
-
-        **Returns:**
-
-        A JAX array of shape `(out_features,)`. (Or shape `()` if
-        `out_features="scalar"`.)
-        """
-        if self.in_features == "scalar":
-            if jnp.shape(x) != ():
-                raise ValueError("x must have scalar shape")
-            x = jnp.broadcast_to(x, (1,))
-        x = self.weight @ x
-        if self.bias is not None:
-            x = x + self.bias
-        if self.out_features == "scalar":
-            assert jnp.shape(x) == (1,)
-            x = jnp.squeeze(x)
-        return x
-
-
class FactoredMLP(eqx.Module, strict=True):
    """Standard Multi-Layer Perceptron; also known as a feed-forward network.
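For reference, a minimal sketch of what this commit relies on: the imported `equinox.nn.Linear` is expected to cover the cases the deleted workaround guarded, assuming the installed Equinox release includes the upstream fix for zero-sized inputs. The variable names below are illustrative, not from the commit:

```python
import jax
import jax.numpy as jnp
from equinox.nn import Linear  # the import this commit switches to

key = jax.random.PRNGKey(0)

# "scalar" mode: the layer maps an input of shape () to shape (4,).
scalar_in = Linear("scalar", 4, key=key)
assert scalar_in(jnp.array(1.0)).shape == (4,)

# Zero-width input: weight has shape (4, 0), so the output is just the bias.
# Without the upstream fix, initialisation computed lim = 1/sqrt(0), which the
# deleted workaround special-cased to lim = 1.0.
zero_in = Linear(0, 4, key=key)
out = zero_in(jnp.zeros((0,)))
assert out.shape == (4,) and jnp.isfinite(out).all()

# Batched inputs still go through jax.vmap, as the removed docstring noted.
batch = jax.vmap(Linear(3, 4, key=key))(jnp.ones((8, 3)))
assert batch.shape == (8, 4)
```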