1+ """
2+ This module implements the `LogLinearCDE` class using JAX and Equinox. The model is a
3+ block-diagonal Linear Controlled Differential Equation (CDE), where the output is
4+ approximated during training using the Log-ODE method.
5+
6+ Attributes of the `LogLinearCDE` model:
7+ - `init_layer`: The linear layer used to initialize the hidden state $h_0$ from the input $x_0$.
8+ - `out_layer`: The linear layer used to produce final predictions from the hidden state.
9+ - `vf_A`: Learnable parameters for the linear vector field, shaped as flattened block matrices.
10+ - `hidden_dim`: The dimension of the hidden state $h_t$.
11+ - `block_size`: Size of each square block in the block-diagonal vector field.
12+ - `num_blocks`: Number of blocks, computed as `hidden_dim // block_size`.
13+ - `parallel_steps`: Chunk size used when composing log-flow matrices in parallel via associative scan.
14+ - `logsig_depth`: The depth of the log-signature used in the Log-ODE method.
15+ - `basis_list`: The list of basis elements of the free Lie algebra up to the specified depth.
16+ - `lambd`: Regularization parameter applied to vector field scaling.
17+ - `w_init_std`: Standard deviation of the Gaussian initialization for the vector-field weights (further scaled by `1 / sqrt(block_size)`).
18+ - `classification`: Whether the model performs classification (softmax over the mean-pooled hidden states) or sequence regression (per-step `tanh` outputs).
19+
20+ The class includes:
21+ - `log_ode`: Method for computing the iterated Lie brackets of the linear vector fields.
22+ - `__call__`: Performs the forward pass, where flows are composed and applied to the hidden state
23+ either step-by-step or in parallel (using associative scan), followed by output projection.
24+ """
25+
126from __future__ import annotations
227
328from typing import List , Tuple
@@ -28,10 +53,6 @@ def depth(b):
2853
2954
3055class LogLinearCDE (eqx .Module ):
31- """
32- Block‑diagonal Linear Controlled Differential Equation layer.
33- """
34-
3556 init_layer : eqx .nn .Linear
3657 out_layer : eqx .nn .Linear
3758 vf_A : jnp .ndarray
@@ -41,10 +62,10 @@ class LogLinearCDE(eqx.Module):
4162 parallel_steps : int
4263 logsig_depth : int
4364 basis_list : List [Tuple [int , ...]]
44- stepsize : int
4565 lambd : float
66+ w_init_std : float
67+ classification : bool
4668
47- classification : bool = True
4869 lip2 : bool = True
4970 nondeterministic : bool = False
5071 stateful : bool = False
@@ -57,10 +78,10 @@ def __init__(
5778 label_dim : int ,
5879 block_size : int ,
5980 logsig_depth : int ,
60- stepsize : int ,
6181 lambd : float = 1.0 ,
6282 w_init_std : float = 0.25 ,
6383 parallel_steps : int = 128 ,
84+ classification : bool = True ,
6485 key ,
6586 ):
6687 if hidden_dim % block_size != 0 :
@@ -70,24 +91,25 @@ def __init__(
7091 self .num_blocks = hidden_dim // block_size
7192 self .parallel_steps = parallel_steps
7293 self .logsig_depth = logsig_depth
73- self .stepsize = stepsize
7494 ctx = rp .get_context (width = data_dim , depth = self .logsig_depth , coeffs = rp .DPReal )
7595 basis = ctx .lie_basis
7696 basis_list = []
7797 for i in range (basis .size (self .logsig_depth )):
7898 basis_list .append (eval (str (basis .index_to_key (i ))))
7999 self .basis_list = basis_list
80100 self .lambd = lambd
101+ self .w_init_std = w_init_std
81102
82103 k_init , k_A , k_B = jr .split (key , 3 )
83104 self .init_layer = eqx .nn .Linear (data_dim , hidden_dim , key = k_init )
84105 self .out_layer = eqx .nn .Linear (hidden_dim , label_dim , key = k_B )
85106
86107 self .vf_A = (
87108 jr .normal (k_A , (data_dim + 1 , self .num_blocks * block_size * block_size ))
88- * w_init_std
109+ * self . w_init_std
89110 / jnp .sqrt (block_size )
90111 )
112+ self .classification = classification
91113
92114 def log_ode (self , vf ):
93115
@@ -111,7 +133,7 @@ def log_ode(self, vf):
111133
112134 left_indices = []
113135 right_indices = []
114- for ( i_b , b ) in curr_elements :
136+ for i_b , b in curr_elements :
115137 u_tuple = to_tuple (b [0 ])
116138 v_tuple = to_tuple (b [1 ])
117139 i_u = basis_index [u_tuple ]
@@ -184,7 +206,12 @@ def parallel_step(y, flows):
184206 inp_rem = flows [- remainder :]
185207 _ , y_rem = jax .lax .scan (step , ys [- 1 ], inp_rem )
186208 ys = jnp .vstack ([ys , y_rem ])
187- ys = jnp .mean (ys , axis = 0 )
188- ys = self .out_layer (ys )
189- preds = jax .nn .softmax (ys )
209+
210+ if self .classification :
211+ ys = jnp .mean (ys , axis = 0 )
212+ preds = jax .nn .softmax (self .out_layer (ys ))
213+ else :
214+ ys = jax .vmap (self .out_layer )(ys )
215+ preds = jnp .tanh (ys )
216+
190217 return preds