A collection of utilities for multi-layer perceptron models.
```bash
pip install -e ./mlp-utils
```
A collection of activation functions for MLPs.
- `ReluSquared`: `max(0, x)^2`, with an option to be signed.
- `Gelu2`: `GELU(x)^2`
- `BSiLU`: `(x + α) * sigmoid(x) - α / 2`
- `NeLU`: `-α / (1 + x^2)`, often used as a backward function in a straight-through estimator (STE).
- `Sugar`: A straight-through estimator that applies the backward function only to the negative part of the input.
- `StraightThroughEstimator`: A generic straight-through estimator that can be configured with different forward and backward passes.
- `ReluNelu`: An activation that uses ReLU in the forward pass and NeLU in the backward pass for the negative part, built on the `Sugar` module.
- `SugarReLU`: A straight-through estimator with a ReLU forward pass and a sigmoid backward pass.
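A minimal usage sketch. The import path `mlp_utils.activations` and the no-argument constructors are assumptions; adjust them to match the actual package layout.

```python
import torch

# Assumed import path; adjust to wherever the activations are exported.
from mlp_utils.activations import ReluSquared, BSiLU, ReluNelu

x = torch.randn(4, 256)

y_relu_sq = ReluSquared()(x)  # max(0, x)^2
y_bsilu = BSiLU()(x)          # (x + α) * sigmoid(x) - α / 2
y_relu_nelu = ReluNelu()(x)   # ReLU forward; NeLU gradient on the negative part
```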
Utilities for initializing weights in neural network layers.
- `initialize_weights`: Initializes the weights of a module with strategies like "default", "gating", "embedding", or "expert".
- `apply_initialization`: Applies initialization to all modules in a model.
- `create_initializer`: Creates a customized initializer function.
```python
import torch.nn as nn

from mlp_utils.layers.init_weights import (
    initialize_weights,
    apply_initialization,
    create_initializer,
)

# Any nn.Module works here; a small stack of linear layers is used for illustration.
model = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 256))

# Create and apply a reusable initializer
initializer = create_initializer(init_method="default", nonlinearity="relu", scale=1.0)
model.apply(initializer)
```
A feed-forward block with optional GLU variants.
```python
from mlp_utils.layers.feedforward import FeedForward

ffn = FeedForward(
    dim=256,
    mult=4,
    glu_variant="swiglu",
)
```
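A forward-pass sketch (the block is assumed to preserve the `(batch, seq_len, dim)` shape of its input):

```python
import torch

x = torch.randn(2, 64, 256)  # (batch, seq_len, dim)
out = ffn(x)                 # assumed shape-preserving: (2, 64, 256)
```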
The `NGPT` class implements the feed-forward block from the paper "nGPT: Normalized Transformer with Representation Learning on the Hypersphere".

This module applies the nGPT update rule, which normalizes the hidden states and uses a learnable interpolation parameter (`alpha_m`) to update the representation on the hypersphere.
You can use it as a standalone layer:
```python
from mlp_utils.layers import NGPT

# Initialize the nGPT feed-forward block
ngpt_feedforward = NGPT(
    dim=256,
)

# The resulting module can be used as a drop-in replacement for a standard feed-forward block.
```
Alternatively, you can provide your own feed-forward network, which will be wrapped with the nGPT update rule:
```python
from mlp_utils.layers import FeedForward, NGPT

# 1. Create a custom feed-forward network
feedforward_net = FeedForward(
    dim=256,
    mult=4,
    glu_variant="swiglu",
)

# 2. Wrap it with the NGPT layer
ngpt_feedforward_wrapped = NGPT(
    feedforward_net=feedforward_net,
    dim=256,
)
```
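Either variant is called like an ordinary feed-forward block. The sketch below assumes the module preserves the input shape:

```python
import torch

x = torch.randn(2, 64, 256)  # (batch, seq_len, dim)
out = ngpt_feedforward_wrapped(x)  # assumed shape-preserving: (2, 64, 256)
```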
The `FastFeedForward` class implements the Fast Feedforward Network from the paper "Fast Feedforward Networks" by Belcak and Wattenhofer. The layer uses a tree of routers to dynamically select a small subset of "expert" `FeedForward` networks for each input token, enabling conditional computation.
This architecture achieves logarithmic time complexity with respect to the number of experts, leading to significant efficiency gains while preserving a high degree of predictive performance.
```python
from mlp_utils.layers import FastFeedForward

# Create a FastFeedForward layer with a tree of depth 3 (2^3 = 8 experts)
fff = FastFeedForward(
    dim=256,
    depth=3,
    mult=4,
    glu_variant="swiglu",
)
```
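A forward-pass sketch (shape preservation is assumed; each token is routed down the tree to a single expert):

```python
import torch

x = torch.randn(2, 64, 256)  # (batch, seq_len, dim)
out = fff(x)                 # assumed shape-preserving: (2, 64, 256)
```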
The `PathWeightedFFF` class implements a hierarchical, path-dependent neural network built on a binary tree structure. Unlike a Mixture-of-Experts (MoE) model that routes an input to a single expert, this network computes its output by combining transformations from every node along the traversed path.

The routing logits themselves are activated with GELU and used as weights to combine the node transformations, allowing the model to learn hierarchical features in a path-dependent manner. This offers a different architectural trade-off compared to sparse MoE layers like `FastFeedForward`.
```python
from mlp_utils.layers import PathWeightedFFF

# Create a PathWeightedFFF layer with a tree of depth 4
pfff = PathWeightedFFF(
    input_width=256,
    depth=4,
    output_width=256,
)
```
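A forward-pass sketch (the last dimension of the output is assumed to equal `output_width`):

```python
import torch

x = torch.randn(2, 64, 256)  # (batch, seq_len, input_width)
out = pfff(x)                # assumed output shape: (2, 64, 256)
```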
The `SwitchFFN` layer implements the Switch Transformer feed-forward layer from the paper "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity".

The layer uses a Mixture-of-Experts (MoE) approach, where a router network dynamically sends each input token to one of several "expert" `FeedForward` networks. This allows for a massive number of parameters while keeping the computational cost per token constant. The layer also includes an auxiliary load-balancing loss that encourages the experts to be utilized evenly.
```python
import torch

from mlp_utils.layers import SwitchFFN

# Create a SwitchFFN layer with 8 experts
switch_ffn = SwitchFFN(
    dim=256,
    num_experts=8,
    capacity_factor=1.25,
    ff_kwargs=dict(mult=4, glu_variant="swiglu"),
)

# The forward pass returns the output and the load-balancing loss
input_tensor = torch.randn(2, 64, 256)  # (batch, seq_len, dim)
output, loss = switch_ffn(input_tensor)
```
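During training, the auxiliary loss is typically added to the task loss with a small weight. The sketch below uses an illustrative weight of 0.01 and a placeholder objective; neither is a library default.

```python
task_loss = output.pow(2).mean()      # placeholder task objective
total_loss = task_loss + 0.01 * loss  # fold in the load-balancing loss
total_loss.backward()
```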
A standardized gating mechanism.
```python
from mlp_utils.layers.gating import GatingMechanism

gate = GatingMechanism(
    input_dim=256,
    bottleneck_factor=0.5,
)
```
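A usage sketch. It assumes the module returns gate values with the same trailing dimension as `input_dim`, which are then applied elementwise to a signal; check the module's documentation for the exact output semantics.

```python
import torch

x = torch.randn(2, 64, 256)
g = gate(x)    # assumed: gate values with trailing dimension input_dim
gated = g * x  # typical pattern: modulate a signal elementwise
```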
Gated Linear Units (GLUs) are feed-forward layers that use multiplicative gating to improve model expressivity. Each GLU variant uses a different activation function for the gate.
- `GLU`: Gated Linear Unit with sigmoid activation
- `Bilinear`: GLU with no activation (identity)
- `ReGLU`: GLU with ReLU activation
- `SwiGLU`: GLU with Swish (SiLU) activation, commonly used in modern LLMs
- `GeGLU`: GLU with GELU activation
```python
from mlp_utils.layers.glu import SwiGLU

# Standard GLU with separate gate and value projections
swiglu = SwiGLU(
    dim_in=256,
    dim_out=512,
    bias=True,
)
```
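A forward-pass sketch (the projection is assumed to map the last dimension from `dim_in` to `dim_out`):

```python
import torch

x = torch.randn(2, 64, 256)
out = swiglu(x)
print(out.shape)  # expected: torch.Size([2, 64, 512])
```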
MGLUs are a memory-efficient variant of GLUs that use a single shared weight matrix for both gate and value projections, with learnable binary masks to differentiate between them. This reduces memory bandwidth requirements during inference while maintaining the expressive power of traditional GLUs.
Based on the paper "Masked Gated Linear Unit" by Tajima et al. (2025), MGLUs implement a "Mixture of Element-wise Gating" (MoEG) architecture that can provide significant memory and computational benefits when implemented with optimized kernels like FlashMGLU.
Available MGLU variants:
- `MGLU`: Masked GLU with sigmoid activation
- `BilinearMGLU`: Masked GLU with no activation (identity)
- `ReMGLU`: Masked GLU with ReLU activation
- `SwiMGLU`: Masked GLU with Swish (SiLU) activation
- `GeMGLU`: Masked GLU with GELU activation
```python
from mlp_utils.layers.glu import SwiMGLU

# Memory-efficient MGLU with a shared weight matrix
swimglu = SwiMGLU(
    dim_in=256,
    dim_out=512,
    bias=True,
)

# The mask parameter is learned during training
print(f"Learnable mask shape: {swimglu.mask.shape}")  # (512,)
```
Note: This implementation provides the MGLU architecture in PyTorch. The significant performance gains reported in the paper require optimized CUDA kernels like FlashMGLU.
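Because the gate and value share one weight matrix, the masked variant should carry roughly half the projection parameters of the corresponding GLU. A quick comparison sketch (the exact counts depend on implementation details such as biases and the mask):

```python
import torch

from mlp_utils.layers.glu import SwiGLU, SwiMGLU

def n_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print(n_params(SwiGLU(dim_in=256, dim_out=512, bias=True)))   # two projection matrices
print(n_params(SwiMGLU(dim_in=256, dim_out=512, bias=True)))  # one shared matrix plus the mask
```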
A standardized MLP module.
```python
from mlp_utils.layers.mlp import MLP

mlp = MLP(
    input_dim=256,
    output_dim=256,
)
```
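A forward-pass sketch (the last dimension is assumed to go from `input_dim` to `output_dim`):

```python
import torch

x = torch.randn(2, 256)
out = mlp(x)  # assumed output shape: (2, 256)
```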
The `GMLP` class implements the gMLP model from the paper "Pay Attention to MLPs".
```python
from mlp_utils.layers.gmlp import GMLP

gmlp = GMLP(
    dim=256,
    dim_ff=1024,
    seq_len=64,
    depth=6,
)
```
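A forward-pass sketch. Since gMLP's spatial gating operates over the sequence dimension, the input sequence length is assumed to match the `seq_len` given at construction:

```python
import torch

x = torch.randn(8, 64, 256)  # (batch, seq_len, dim); seq_len matches the constructor
out = gmlp(x)                # assumed shape-preserving: (8, 64, 256)
```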
`L2Norm`: Normalizes a tensor to have a unit L2 norm along a given dimension.
```python
from mlp_utils import L2Norm

norm = L2Norm(dim=-1)
```
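A quick check of the behaviour described above:

```python
import torch

x = torch.randn(4, 256)
y = norm(x)
print(y.norm(dim=-1))  # each row should have (approximately) unit L2 norm
```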
`ResidualWrapper`: Adds a residual connection to any module.
```python
import torch.nn as nn

from mlp_utils.layers import ResidualWrapper

residual_mlp = ResidualWrapper(
    nn.Sequential(
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
    )
)
```
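A forward-pass sketch; per the residual connection, the output equals the input plus the wrapped module's output.

```python
import torch

x = torch.randn(2, 256)
out = residual_mlp(x)  # x + wrapped_module(x)
```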
Model | Compile | Params | Runtime (s) | Configuration | Final Loss | Status |
---|---|---|---|---|---|---|
mlp | True | 33.41K | 10.15 | act_fn=GELU | 0.221746 | Success |
mlp | True | 33.41K | 2.40 | act_fn=ReLU | 0.225678 | Success |
mlp | True | 33.41K | 2.39 | act_fn=SiLU | 0.215601 | Success |
mlp | True | 33.41K | 2.35 | act_fn=ReluSquared | 0.246249 | Success |
mlp | True | 33.41K | 2.53 | act_fn=Gelu2 | 0.252107 | Success |
mlp | True | 33.41K | 2.44 | act_fn=BSiLU | 0.211360 | Success |
mlp | True | 33.41K | 2.51 | act_fn=ReluNelu (forward_fn=ReLU, backward_fn=NeLU) | 0.223811 | Success |
mlp | True | 33.41K | 2.38 | act_fn=GELU, residual=True | 0.255551 | Success |
mlp | True | 33.09K | 1.82 | act_fn=GELU, use_norm=False | 0.220111 | Success |
mlp | True | 33.22K | 2.38 | act_fn=GELU, pre_norm=True | 0.221599 | Success |
feedforward | True | 33.09K | 1.76 | glu_variant=none, activation=GELU | 0.207531 | Success |
feedforward | True | 49.73K | 1.90 | glu_variant=glu | 0.179443 | Success |
feedforward | True | 49.73K | 1.87 | glu_variant=swiglu | 0.306741 | Success |
feedforward | True | 49.73K | 1.89 | glu_variant=geglu | 0.304403 | Success |
feedforward | True | 49.73K | 1.82 | glu_variant=reglu | 0.293338 | Success |
feedforward | True | 49.73K | 1.70 | glu_variant=bilinear | 0.318753 | Success |
feedforward | True | 33.34K | 1.76 | glu_variant=mglu | 0.189905 | Success |
feedforward | True | 33.34K | 1.79 | glu_variant=mswiglu | 0.204473 | Success |
feedforward | True | 33.34K | 1.82 | glu_variant=mgeglu | 0.211235 | Success |
feedforward | True | 33.34K | 1.78 | glu_variant=mreglu | 0.213310 | Success |
feedforward | True | 33.34K | 1.73 | glu_variant=mbilinear | 0.236762 | Success |
fastfeedforward | True | 398.28K | 18.49 | glu_variant=swiglu, expert_dim=8 | 0.176578 | Success |
fastfeedforward | True | 398.28K | 18.79 | glu_variant=geglu, expert_dim=8 | 0.176054 | Success |
fastfeedforward | True | 267.21K | 18.12 | glu_variant=mswiglu, expert_dim=8 | 0.173000 | Success |
fastfeedforward | True | 398.28K | 17.40 | glu_variant=swiglu, expert_dim=8 | 0.171305 | Success |
fastfeedforward | True | 398.28K | 17.44 | glu_variant=swiglu, expert_dim=8 | 0.174691 | Success |
fastfeedforward | True | 398.28K | 18.10 | glu_variant=swiglu, expert_dim=8 | 0.172715 | Success |
pathweightedfff | True | 63.38K | 20.21 | depth=3 | 0.246856 | Success |
pathweightedfff | True | 63.38K | 20.27 | depth=3, activation=silu | 0.250302 | Success |
pathweightedfff | True | 266.18K | 25.57 | depth=5 | 0.237995 | Success |
ngpt | True | 49.73K | 2.48 | scalar_alpha=True | 0.005054 | Success |
ngpt | True | 49.79K | 2.40 | scalar_alpha=False | 0.005193 | Success |
gmlp | True | 218.11K | 10.33 | | 0.049925 | Success |
switch_ffn | True | 398.34K | 10.07 | num_experts=8, ff_kwargs={'mult': 4, 'glu_variant': 'swiglu'} | 0.274261 | Success |
switch_ffn | True | 399.36K | 17.40 | num_experts=16, ff_kwargs={'mult': 2, 'glu_variant': 'geglu'} | 0.362581 | Success |
- [ ] Enhancing Fast Feed Forward Networks with Load Balancing and a Master Leaf Node
- [ ] Limit soft routing to the top-k paths (stochastic top-k or beam search) to reduce compute during training
```bibtex
@article{Zhang2024ReLU2WD,
  title = {ReLU2 Wins: Discovering Efficient Activation Functions for Sparse LLMs},
  author = {Zhengyan Zhang and Yixin Song and Guanghui Yu and Xu Han and Yankai Lin and Chaojun Xiao and Chenyang Song and Zhiyuan Liu and Zeyu Mi and Maosong Sun},
  journal = {ArXiv},
  year = {2024},
  volume = {abs/2402.03804},
  url = {https://api.semanticscholar.org/CorpusID:267499856}
}

@inproceedings{Horuz2025TheRO,
  title = {The Resurrection of the ReLU},
  author = {Cocsku Can Horuz and Geoffrey Kasenbacher and Saya Higuchi and Sebastian Kairat and Jendrik Stoltz and Moritz Pesl and Bernhard A. Moser and Christoph Linse and Thomas Martinetz and Sebastian Otte},
  year = {2025},
  url = {https://api.semanticscholar.org/CorpusID:278959515}
}

@misc{shazeer2020glu,
  title = {GLU Variants Improve Transformer},
  author = {Noam Shazeer},
  year = {2020},
  url = {https://arxiv.org/abs/2002.05202}
}

@misc{tajima2025maskedgatedlinearunit,
  title = {Masked Gated Linear Unit},
  author = {Yukito Tajima and Nakamasa Inoue and Yusuke Sekikawa and Ikuro Sato and Rio Yokota},
  year = {2025},
  url = {https://arxiv.org/abs/2506.23225},
}

@misc{loshchilov2025ngptnormalizedtransformerrepresentation,
  title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
  author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
  year = {2025},
  url = {https://arxiv.org/abs/2410.01131},
}

@misc{liu2021payattentionmlps,
  title = {Pay Attention to MLPs},
  author = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
  year = {2021},
  url = {https://arxiv.org/abs/2105.08050},
}

@misc{belcak2023fastfeedforwardnetworks,
  title = {Fast Feedforward Networks},
  author = {Peter Belcak and Roger Wattenhofer},
  year = {2023},
  url = {https://arxiv.org/abs/2308.14711},
}

@misc{fedus2022switchtransformersscalingtrillion,
  title = {Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity},
  author = {William Fedus and Barret Zoph and Noam Shazeer},
  year = {2022},
  eprint = {2101.03961},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  url = {https://arxiv.org/abs/2101.03961},
}

@misc{belcak2023exponentiallyfasterlanguagemodelling,
  title = {Exponentially Faster Language Modelling},
  author = {Peter Belcak and Roger Wattenhofer},
  year = {2023},
  url = {https://arxiv.org/abs/2311.10770},
}
```