NeST-S6: Nested Convolutional Spatiotemporal (PDE-aware) State-space Model for 5G Network Traffic Forecasting
NeST-S6 (“Nest S 6”) is a nested-learning spatiotemporal forecasting model designed for grid-based cellular traffic prediction, inspired in part by Google Research’s Nested Learning paradigm (arXiv:2512.24695, blog post). It integrates a fast per-step spatiotemporal predictor (convolution + windowed 2D attention + PDE-aware selective state-space updates) with a slow persistent memory updated by a learned Deep Optimizer.
- Nested Learning Memory: a slow, persistent spatial memory that improves robustness under global/dynamic drift, inspired by "Nested Learning: The Illusion of Deep Learning Architectures" (paper, blog post).
- PDE-aware SSM Core: spatially varying selective state-space updates with stable exponential discretization.
- Local Mixing + Windowed Attention: depthwise convolution and windowed attention for efficient 2D context.
- End-to-end Grid Forecasting: predicts full spatial kernels/patches at once (instead of scalar-per-pixel).
- Autoregressive Rollout: built-in multi-step forecasting via recursive prediction.
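The rollout idea can be sketched as follows; note that the `rollout` helper and the stand-in model below are hypothetical illustrations, not part of the package API (the actual model exposes `steps_to_predict`, shown in the usage section):

```python
import torch

def rollout(model, x, steps):
    """Hypothetical autoregressive rollout: feed each prediction back in."""
    frames = [x]                                 # x: (B, T, H, W)
    for _ in range(steps):
        nxt = model(torch.cat(frames, dim=1))    # predict next frame (B, H, W)
        frames.append(nxt.unsqueeze(1))          # append as a new time step
    return torch.cat(frames, dim=1)              # (B, T + steps, H, W)

toy = lambda seq: seq[:, -1]                     # stand-in "model": repeats last frame
out = rollout(toy, torch.randn(2, 6, 20, 20), steps=2)
print(out.shape)  # torch.Size([2, 8, 20, 20])
```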
Let $\mathbf{x}_t \in \mathbb{R}^{H \times W}$ denote the traffic grid at time $t$, with $H \times W$ spatial cells. The goal is to predict the future spatial kernel at the next time step, $\hat{\mathbf{x}}_{t+1}$, from the observed history $\mathbf{x}_{1:t}$. We learn parameters $\theta$ of a predictor $f_\theta$ by minimizing a one-step forecasting loss between $\hat{\mathbf{x}}_{t+1} = f_\theta(\mathbf{x}_{1:t})$ and $\mathbf{x}_{t+1}$.
NeST-S6 is a nested-learning spatiotemporal predictor with two interacting loops:
- Fast Learner: per-step dynamics (local mixing + windowed attention + a spatial selective state-space core)
- Slow Learner: a persistent spatial memory updated by a learned “Deep Optimizer”
Below we summarize the core math, including tensor shapes.
Given a traffic grid sequence $\mathbf{X}_{1:T} \in \mathbb{R}^{B \times T \times H \times W}$ (batch $B$, time $T$), a convolutional stem projects each frame into features $\mathbf{x} \in \mathbb{R}^{B \times D_{\text{in}} \times H \times W}$.
Each NeST-S6 Block then applies: (i) depthwise convolution + SiLU for local mixing, (ii) windowed 2D attention for local context, and (iii) an S-PDE SSM core update (S for spatial, PDE-aware selective state-space update).
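The first two stages of a block can be sketched as follows. This is an illustrative reading, not the reference implementation: the module name, head count, and window partitioning are assumptions, and the S-PDE SSM core (stage iii) is described separately below.

```python
# Hedged sketch of stages (i) and (ii) of a NeST-S6 block:
# depthwise conv + SiLU local mixing, then attention within
# non-overlapping spatial windows. Names are illustrative.
import torch
import torch.nn as nn

class LocalMixWindowAttn(nn.Module):
    def __init__(self, d_model=48, d_conv=3, attn_window=4, n_heads=4):
        super().__init__()
        # (i) depthwise convolution mixes each channel locally
        self.dw = nn.Conv2d(d_model, d_model, d_conv,
                            padding=d_conv // 2, groups=d_model)
        self.act = nn.SiLU()
        # (ii) attention restricted to non-overlapping windows
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.w = attn_window

    def forward(self, x):  # x: (B, D, H, W); H, W divisible by the window
        x = self.act(self.dw(x))
        B, D, H, W = x.shape
        w = self.w
        # partition into (B * n_windows, w*w, D) token groups
        t = x.view(B, D, H // w, w, W // w, w)
        t = t.permute(0, 2, 4, 3, 5, 1).reshape(-1, w * w, D)
        t, _ = self.attn(t, t, t, need_weights=False)
        # undo the window partition
        t = t.view(B, H // w, W // w, w, w, D)
        return t.permute(0, 5, 1, 3, 2, 4).reshape(B, D, H, W)

x = torch.randn(2, 48, 8, 8)
y = LocalMixWindowAttn()(x)
print(y.shape)  # torch.Size([2, 48, 8, 8])
```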
Let $\mathbf{x} \in \mathbb{R}^{B \times D_{\text{in}} \times H \times W}$ denote the block's input features, with state size $D_s$ and rank $r \ll D_s$.

Dynamic low-rank parameter generation. A $1\times 1$ convolution maps $\mathbf{x}$ to per-pixel step sizes and rank-$r$ coefficient maps. The per-channel step size is lifted to a spatially varying $\boldsymbol{\Delta} \in \mathbb{R}^{B \times D_{\text{in}} \times H \times W}$. We use learnable low-rank bases to expand the rank-$r$ coefficient maps into full spatially varying tensors. This yields $\mathbf{B}_{\text{dyn}}, \mathbf{C}_{\text{dyn}} \in \mathbb{R}^{B \times D_{\text{in}} \times D_s \times H \times W}$. With learned gates $g_B=\sigma(\text{Conv}_{1\times 1}(\mathbf{x}))$ and $g_C=\sigma(\text{Conv}_{1\times 1}(\mathbf{x}))$, we define effective parameters (with a learned scale) $\mathbf{B} = g_B \odot \mathbf{B}_{\text{dyn}}$ and $\mathbf{C} = g_C \odot \mathbf{C}_{\text{dyn}}$.
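As a hedged sketch of the low-rank expansion and gating, with illustrative projection layers and basis shapes (randomly initialized here; the actual parameterization may differ):

```python
import torch

B, D_in, D_s, H, W, r = 2, 16, 8, 10, 10, 2
x = torch.randn(B, D_in, H, W)

# hypothetical learnable pieces (randomly initialized for illustration)
coef_proj = torch.nn.Conv2d(D_in, r, 1)      # per-pixel rank-r coefficients
basis_B = torch.randn(r, D_in, D_s)          # low-rank basis for B
gate_proj = torch.nn.Conv2d(D_in, D_in, 1)   # produces the gate g_B

coef = coef_proj(x)                          # (B, r, H, W)
# expand the rank-r maps into a full spatially varying tensor
B_dyn = torch.einsum('brhw,rds->bdshw', coef, basis_B)  # (B, D_in, D_s, H, W)
g_B = torch.sigmoid(gate_proj(x)).unsqueeze(2)          # (B, D_in, 1, H, W)
B_eff = g_B * B_dyn                          # effective, gated parameters
print(B_eff.shape)  # torch.Size([2, 16, 8, 10, 10])
```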
Stable transition parameterization. The base transition is parameterized for stability as $\mathbf{A} = -\exp(\mathbf{A}_{\log})$, which guarantees a negative real part. Optionally, a low-rank modulation of $\mathbf{A}$ (of rank `a_mod_rank`) makes the transition input-dependent, and the effective per-pixel transition is set from the base and the learned modulation.

Exponential discretization (per-pixel selective update). The state update is

$$\mathbf{h}_t = \exp(\boldsymbol{\Delta} \odot \mathbf{A}) \odot \mathbf{h}_{t-1} + (\boldsymbol{\Delta} \odot \mathbf{B}) \odot \mathbf{x}_t.$$

The output features (summing over the state dimension) include a skip term:

$$\mathbf{y}_t = \textstyle\sum_{s}\,(\mathbf{C} \odot \mathbf{h}_t)_s + \mathbf{D} \odot \mathbf{x}_t.$$
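The discretized recurrence can be sketched with plain tensors; all shapes and the random parameters below are illustrative assumptions (a real implementation would generate `Bp`, `Cp`, and `delta` from the input, as described above):

```python
import torch

B, D, S, H, W, T = 2, 4, 3, 6, 6, 5
x = torch.randn(B, T, D, H, W)              # per-step input features
A = -torch.rand(D, S)                       # stable: negative real part
Bp = torch.randn(B, T, D, S, H, W)          # per-pixel input parameters
Cp = torch.randn(B, T, D, S, H, W)          # per-pixel output parameters
Dskip = torch.randn(D)                      # skip term
delta = torch.rand(B, T, D, H, W)           # per-pixel step sizes

h = torch.zeros(B, D, S, H, W)
ys = []
for t in range(T):
    # exponential discretization of the transition, per pixel and channel
    dA = torch.exp(delta[:, t].unsqueeze(2) * A[None, :, :, None, None])
    h = dA * h + delta[:, t].unsqueeze(2) * Bp[:, t] * x[:, t].unsqueeze(2)
    # sum over the state dimension, plus skip connection
    y = (Cp[:, t] * h).sum(dim=2) + Dskip[None, :, None, None] * x[:, t]
    ys.append(y)
out = torch.stack(ys, dim=1)                # (B, T, D, H, W)
print(out.shape)
```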
NeST-S6 maintains a long-term spatial memory $\mathbf{M}_t \in \mathbb{R}^{D \times H \times W}$, updated by an outer-loop learned optimizer. At each step, the Fast Learner's prediction is compared against the incoming frame to generate a surprise signal, based on the one-step error $\mathbf{e}_t = \mathbf{x}_{t+1} - \hat{\mathbf{x}}_{t+1}$. We project this error into the model dimension using a small convolutional stem to obtain the surprise $\mathbf{S}_t$. The Deep Optimizer (2D) produces a memory update $\Delta\mathbf{M}_t = \phi_{\text{opt}}(\mathbf{z}_t, \mathbf{S}_t)$, and the memory is updated with a learned decay factor $\lambda$:

$$\mathbf{M}_{t+1} = \lambda\,\mathbf{M}_t + (1-\lambda)\,\Delta\mathbf{M}_t.$$
Intuitively, the Slow Learner performs a form of online adaptation by updating a global memory more aggressively when the Fast Learner’s one-step prediction error is high.
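A toy sketch of this error-driven memory write; the decay value and the stand-in `phi_opt` are assumptions (the actual Deep Optimizer is a learned network), so this only illustrates the update rule:

```python
import torch

D, H, W = 8, 6, 6
M = torch.zeros(D, H, W)              # persistent spatial memory M_t
lam = 0.9                             # learned decay (fixed here for illustration)

def phi_opt(z, S):
    # stand-in for the learned Deep Optimizer: an error-weighted write,
    # so larger surprise drives a larger memory update
    return S.abs().mean() * z

z = torch.randn(D, H, W)              # fast-learner features z_t
e = torch.randn(1, H, W)              # one-step error x_{t+1} - x_hat_{t+1}
S = e.expand(D, H, W)                 # toy "surprise" projection S_t
dM = phi_opt(z, S)
M = lam * M + (1 - lam) * dM          # M_{t+1} = lam * M_t + (1 - lam) * dM_t
print(M.shape)
```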
Memory is fused back into the Fast Learner context through a learned Context Gate. Given the fast features $\mathbf{z}_t$ and the memory $\mathbf{M}_t$, the gate produces a fused context $\tilde{\mathbf{z}}_t$, and $\tilde{\mathbf{z}}_t$ is passed through the fast stack to obtain the next prediction $\hat{\mathbf{x}}_{t+1}$ via a convolutional prediction head.
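One plausible form of such a gate, sketched under the assumption of a concatenation followed by a $1\times 1$ convolution (the reference implementation may differ):

```python
import torch

D, H, W = 8, 6, 6
z = torch.randn(1, D, H, W)           # fast-learner features z_t
m = torch.randn(1, D, H, W)           # memory readout from M_t
gate = torch.nn.Conv2d(2 * D, D, 1)   # hypothetical 1x1 gate projection
g = torch.sigmoid(gate(torch.cat([z, m], dim=1)))  # per-pixel gate in [0, 1]
z_tilde = (1 - g) * z + g * m         # gated fusion of memory into context
print(z_tilde.shape)
```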
- NeST: Nested Spatiotemporal (nested fast/slow learners).
- S6: indicates a selective state-space core (SSM) in the “S4/S5/S6-style” family, adapted here to 2D spatial grids with per-pixel (spatially varying) parameters.
```
git clone git@github.com:ZineddineBtc/NeST-S6.git
cd NeST-S6
pip install -e .
```

Requirements:
- Python 3.10
- PyTorch==2.6.0
```python
import torch
from nest_s6 import NeST_S6

# Input tensor: (B, T, H, W)
B, T, H, W = 2, 6, 20, 20
x = torch.randn(B, T, H, W)

model = NeST_S6(n_layers=2, d_model=48, d_state=8, d_conv=3,
                expand=2, attn_window=8, a_mod_rank=2)

with torch.no_grad():
    # last-frame prediction: (B, H, W)
    y_last = model(x)
    print("Last prediction:", y_last.shape)

    # full sequence prediction: (B, T, H, W)
    y_seq = model(x, return_sequence=True)
    print("Sequence prediction:", y_seq.shape)

    # autoregressive rollout (+2): (B, T+2, H, W)
    y_roll = model(x, steps_to_predict=2, return_sequence=True)
    print("Rollout prediction:", y_roll.shape)
```

What's special about NeST-S6 training is the nested memory loop:
- Fast learner: predicts the next frame $\hat{\mathbf{x}}_{t+1}$ from recent history.
- Surprise signal: compute the prediction error $\mathbf{e}_t = \mathbf{x}_{t+1} - \hat{\mathbf{x}}_{t+1}$ and project it to a latent "surprise" $\mathbf{S}_t$.
- Slow learner (memory write): update a persistent memory $\mathbf{M}_t$ using a learned optimizer and decay: $\mathbf{M}_{t+1} = \lambda\mathbf{M}_t + (1-\lambda)\Delta\mathbf{M}_t$, where $\Delta\mathbf{M}_t = \phi_{\text{opt}}(\mathbf{z}_t, \mathbf{S}_t)$.
In practice this means memory updates are error-driven during training (teacher-forced with access to the ground-truth next frame $\mathbf{x}_{t+1}$).
PyTorch-style sketch:

```python
import torch
import torch.nn.functional as F

model.train()
optimizer.zero_grad()

# batch_X: (B, T, H, W), batch_y: (B, H, W)  # next-frame target
# internally uses fast dynamics + (optionally) updates slow memory from surprise
pred = model(batch_X)

# Main one-step forecasting objective (e.g., Huber / SmoothL1)
loss_main = F.smooth_l1_loss(pred, batch_y, beta=1.0)

# Optional physics-inspired spatial regularizer (example: Laplacian consistency)
# loss_lapl = laplacian_loss(pred, batch_y)
# loss = loss_main + w_lapl * loss_lapl
loss = loss_main

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
```

If you want to reproduce the "nested-learning" behavior precisely, the key requirement is that the memory update must consume a surprise derived from the one-step prediction error during training; without that, you are effectively training only the fast learner.
If you use NeST-S6 in your research, please cite the (pending-review) paper and/or the repository:
```bibtex
@software{Bettouche2026NeSTS6Repo,
  title  = {NeST-S6 (Reference Implementation)},
  author = {Bettouche, Zineddine and Ali, Khalid and Fischer, Andreas and Kassler, Andreas},
  year   = {2026},
  url    = {https://github.com/ZineddineBtc/NeST-S6},
}
```
```
NeST-S6/
├── nest_s6/
│   ├── __init__.py      # Exports NeST_S6
│   ├── model.py         # Core NeST-S6 architecture
│   └── __version__.py
├── nests6-arch.png
├── pyproject.toml
└── README.md
```
