Use the command `python main.py --device=cuda --dataset=shakespeare_char` to train on the character-level Shakespeare dataset.
This project replaces attention in a traditional GPT-2-based transformer with my idea, the matrix recurrent unit (MRU), which has sub-quadratic complexity. This repo is forked from my repo transformer-train-script.
Based on testing on the shakespeare_char toy dataset, the MRU seems to work well as a replacement for attention.

The above loss plot is from the first training attempt, which used the independent-heads branch of this repo and my other repo, https://github.com/mikayahlevi/transformer-train-script.
I have limited compute and limited experience with training large models, so I haven't been able to test the LM on much beyond the toy dataset. As a first step, I would like to test it on larger and more informative datasets. If you want to help with this or would like to contact me, reach out at mikayahlevi@gmail.com.
I've also begun writing a CUDA PyTorch extension which can significantly speed up the operation.
The idea of a matrix recurrent unit is dictated by the update rule
- Matrix multiplication is associative but not commutative. The associativity means I can compute the cumulative matrix product using an (inclusive) parallel scan. The lack of commutativity means that the order of the inputs is automatically incorporated into the output by the MRU.
- When you try to do this scan on a traditional RNN, the number of operations scales cubically with the number of elements in the output state, meaning that limited information is retained compared to the amount of computation. On the other hand, if the states are matrices, the number of operations as a function of the number of elements in the output state is $((d_o)^2)^\frac{3}{2}$, where $(d_o)^2$ is the number of elements in the square $d_o \times d_o$ output matrix state. Some more info here: https://arxiv.org/abs/1709.04057.
- When processing the tokens sequentially, the network scales linearly with time, in contrast to attention, which scales quadratically with sequence length.
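The first point above can be illustrated with a minimal sketch; it is not the code in this repo. It assumes the convention $H_1 = X_1$, $H_t = H_{t-1} X_t$ for the cumulative product and uses small integer matrices stored as plain Python lists so the comparison is exact. The names `matmul`, `sequential_scan`, and `hillis_steele_scan` are illustrative only.

```python
def matmul(a, b):
    """Multiply two square matrices stored as lists of rows."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def sequential_scan(xs):
    """Reference: H_1 = X_1, H_t = H_{t-1} X_t, one step at a time."""
    out = [xs[0]]
    for x in xs[1:]:
        out.append(matmul(out[-1], x))
    return out


def hillis_steele_scan(xs):
    """Inclusive scan in ceil(log2(n)) rounds.

    Within a round every product is independent of the others,
    which is what lets parallel hardware compute them at once.
    """
    out = list(xs)
    offset = 1
    while offset < len(out):
        # The earlier partial product is the left operand, so order is kept.
        out = [out[i] if i < offset else matmul(out[i - offset], out[i])
               for i in range(len(out))]
        offset *= 2
    return out


# Six small integer 2x2 matrices (assumed example data).
xs = [[[1, t], [t + 1, 2]] for t in range(1, 7)]

# Associativity: regrouping the products does not change the result.
assert hillis_steele_scan(xs) == sequential_scan(xs)

# Non-commutativity: swapping two inputs changes the product.
assert matmul(xs[0], xs[1]) != matmul(xs[1], xs[0])
```

The scan only ever regroups neighbouring products and never reorders them, so it relies on associativity alone and the sequence order still shows up in the output.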
For the rest of this document, let's call the sequence length
The number of operations for the MRU itself is:
- Using recurrence
- Using the Brent-Kung scan
- Using the Hillis-Steele scan
- Using the CUDA kernel (Sklansky)
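As a rough illustration of how the four strategies in the list differ, the sketch below counts how many times the combine operation is called by textbook versions of each scan. Several things here are assumptions: the sequence length is a power of two, string concatenation stands in for matrix multiplication (it is also associative and non-commutative), and the function names are made up for this example. These are not the repo's kernels, and the counts do not include the cost of each individual matrix product.

```python
def recurrent(xs, op):
    out = [xs[0]]
    for x in xs[1:]:
        out.append(op(out[-1], x))
    return out


def brent_kung(xs, op):
    out, n = list(xs), len(xs)
    d = 1
    while d < n:  # up-sweep: build block products
        for i in range(2 * d - 1, n, 2 * d):
            out[i] = op(out[i - d], out[i])
        d *= 2
    d //= 2
    while d >= 1:  # down-sweep: fill in the remaining prefixes
        for i in range(3 * d - 1, n, 2 * d):
            out[i] = op(out[i - d], out[i])
        d //= 2
    return out


def hillis_steele(xs, op):
    out, d = list(xs), 1
    while d < len(out):
        out = [out[i] if i < d else op(out[i - d], out[i])
               for i in range(len(out))]
        d *= 2
    return out


def sklansky(xs, op):
    out, n = list(xs), len(xs)
    d = 1
    while d < n:
        for i in range(n):
            if i & d:  # i lies in the upper half of its block of size 2d
                j = (i | (d - 1)) - d  # last index of the lower half
                out[i] = op(out[j], out[i])
        d *= 2
    return out


def count_ops(scan, n):
    """Run `scan` on n letters and return how often the combine was called."""
    calls = 0

    def concat(a, b):
        nonlocal calls
        calls += 1
        return a + b

    letters = [chr(ord("a") + i) for i in range(n)]
    result = scan(letters, concat)
    # Every variant must produce the same inclusive prefixes.
    assert result == ["".join(letters[: i + 1]) for i in range(n)]
    return calls


for scan in (recurrent, brent_kung, hillis_steele, sklansky):
    print(scan.__name__, count_ops(scan, 8))
```

For a sequence of length 8 the counts are 7, 11, 17, and 12 combines respectively. With the naive algorithm, each combine in the MRU would be one $d_o \times d_o$ matrix product.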
The parallel scans take more computation, but they have the advantage of using parallel hardware more efficiently. While processing recurrently would take
The MRU should take in a sequence of vectors and return a sequence of vectors, like any other traditional operation in a neural network. For now I'll be ignoring the batch and sequence dimensions and only focus on the last dimension.
Therefore,
Recent developments in state space models have branched out from Mamba 1 and 2, which introduce selectivity and reformulate state space systems to bring them closer to linear attention. Models in the selective state space family, such as Mamba 1/2, RWKV7, and DeltaNet, use a significantly reduced state matrix (simply a scalar in the case of Mamba 2) to avoid the more computationally expensive linear transformations. These changes allow the systems to be expressed as computationally efficient weighted sums of vectors (similar to attention). The MRU takes the opposite approach: it drops the added terms from state space systems and instead focuses on making the state matrix selective/data-dependent and efficient, so that each update transforms the last state.
If we assume that a model can only store an amount of information proportional to the number of scalars it can access, the MRU is at a disadvantage. With attention, it's fair to assume that every value is accessible at every future timestep, meaning that attention can store information proportional to the value size times the sequence length. Mamba-like models also store a large amount of information, in a matrix of shape state size by hidden size that is accessed by a query-like vector at every timestep. Unfortunately, the MRU only has a single square matrix whose number of elements equals the state size, which potentially means that the information has to be compressed into a much smaller representation.
For the MRU, I've derived an efficient algorithm using a parallel scan to compute it. Sorry for my (most likely) incorrect mathematical notation. I am not well versed in all of the math fields that this computation involves. Note that the
The forward pass can be computed using a parallel scan.
The backward pass for the MRU is considerably more complicated.
The gradient of
The expanded gradient of
If we define
I'll call the second part of the gradient a new variable,
You can see
If we let
and
then we can express the equation with
Which can be computed with a reverse parallel scan because matrix multiplication is associative.
Combining all of this, we get the final gradient for the input matrices,
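The structure of the backward pass can be checked numerically on a tiny case. The sketch below is an illustration under stated assumptions, not the derivation or the implementation used in this repo: it assumes the convention $H_1 = X_1$, $H_t = H_{t-1} X_t$, it uses a plain sequential reverse loop where the text uses a reverse parallel scan, and all names (`forward`, `backward`, `loss`, `ws`) are made up for the example. The loss is a fixed weighted sum of the entries of every state, so the gradient with respect to each state is simply the weight matrix `ws[t]`.

```python
def matmul(a, b):
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def transpose(a):
    return [list(row) for row in zip(*a)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def forward(xs):
    """Cumulative products H_t = H_{t-1} X_t with H_1 = X_1."""
    hs = [xs[0]]
    for x in xs[1:]:
        hs.append(matmul(hs[-1], x))
    return hs


def backward(xs, hs, grad_hs):
    """Gradient w.r.t. each input matrix, given gradients w.r.t. each state."""
    n = len(xs)
    grad_xs = [None] * n
    g = grad_hs[-1]  # total gradient flowing into the last state
    for t in range(n - 1, 0, -1):
        grad_xs[t] = matmul(transpose(hs[t - 1]), g)
        # Pass the gradient back one step and add that state's own gradient.
        g = add(grad_hs[t - 1], matmul(g, transpose(xs[t])))
    grad_xs[0] = g
    return grad_xs


def loss(xs, ws):
    """Weighted sum of all state entries, so dL/dH_t = ws[t]."""
    return sum(w * h
               for wm, hm in zip(ws, forward(xs))
               for wr, hr in zip(wm, hm)
               for w, h in zip(wr, hr))


xs = [[[1, t], [t + 1, 2]] for t in range(1, 5)]   # assumed example inputs
ws = [[[t, -1], [2, t - 3]] for t in range(1, 5)]  # assumed loss weights
grads = backward(xs, forward(xs), ws)

# The loss is linear in any single input entry, so bumping one entry by 1
# changes the loss by exactly that entry's partial derivative.
for t in range(len(xs)):
    for i in range(2):
        for j in range(2):
            bumped = [[row[:] for row in x] for x in xs]
            bumped[t][i][j] += 1
            assert loss(bumped, ws) - loss(xs, ws) == grads[t][i][j]
```

The reverse loop carries a running gradient from the last timestep back to the first, which is the quantity a reverse scan would compute in parallel.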