Official PyTorch Lightning implementation of our paper:
Variance Control via Weight Rescaling in LLM Pretraining
Louis Owen, Abhay Kumar, Nilabhra Roy Chowdhury, Fabian Güra
BluOrion
This paper explores the vital role of initializing and managing the standard deviation of weights during LLM training. Our findings demonstrate that better variance management yields substantial improvements in downstream task performance (up to 4.6%) and reduces extreme activation values, thereby mitigating challenges associated with quantization and low-precision training.
Our code is built within the PyTorch Lightning framework, utilizing its callback system for efficient integration into the training pipeline.
Layer Index Rescaling (LIR) scales the initialization standard deviation of layer l by a depth-dependent factor, where l is the layer index. The code for LIR can be found in `patches/initialization.py`.
The code is written as a patch: you pass your model to it, and the weights are re-initialized in place. If you apply other patches, make sure to apply this one last.
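For intuition, here is a minimal sketch of depth-dependent initialization. The `1 / sqrt(2 * (l + 1))` schedule and the `depth_scaled_init` helper are illustrative placeholders, not the repo's code; the exact factor used by LIR is defined in `patches/initialization.py`.

```python
import math

import torch.nn as nn


def depth_scaled_init(layers, base_std: float = 0.02) -> None:
    """Illustrative depth-dependent init: deeper layers get a smaller std.

    The 1 / sqrt(2 * (l + 1)) schedule below is a hypothetical placeholder;
    see patches/initialization.py for the actual LIR factor.
    """
    for l, layer in enumerate(layers):
        std = base_std / math.sqrt(2 * (l + 1))  # placeholder schedule
        for module in layer.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=std)
```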
Re-initialize a Llama-style model (e.g. `LlamaForCausalLM`) with LIR:

```python
from weight_rescaling import re_initialize_model

re_initialize_model(
    model,  # HF model, e.g. LlamaForCausalLM
    model_layer_path="model.layers",
    initializer_range=0.006,
    scale_with_depth=True,  # LIR
    nlayer_config_name="num_hidden_layers",
)
```

Re-initialize a Llama-style model with GPT-2-style special scaling of the residual projections instead of LIR:

```python
from weight_rescaling import re_initialize_model

re_initialize_model(
    model,  # HF model, e.g. LlamaForCausalLM
    model_layer_path="model.layers",
    initializer_range=0.02,
    scale_with_depth=False,
    special_init_residual_module_names=["o_proj", "down_proj"],
    nlayer_config_name="num_hidden_layers",
)
```

Re-initialize a GPT-2 model (e.g. `GPT2LMHeadModel`) with LIR:

```python
from weight_rescaling import re_initialize_model

re_initialize_model(
    model,  # HF model, e.g. GPT2LMHeadModel
    model_layer_path="transformer.h",
    initializer_range=0.02,
    scale_with_depth=True,  # LIR
    nlayer_config_name="n_layer",
)
```
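To check that the patch took effect, you can inspect per-layer weight standard deviations. A minimal sketch, assuming a Llama-style module layout (`model.model.layers[i].self_attn.o_proj`):

```python
# Inspect per-layer std after re-initialization; with LIR enabled,
# the std should shrink as the layer index grows.
for i, layer in enumerate(model.model.layers):
    std = layer.self_attn.o_proj.weight.std().item()
    print(f"layer {i:2d}: o_proj weight std = {std:.5f}")
```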
The code for Target Variance Rescaling (TVR) can be found in `callbacks/tvr.py`. TVR is implemented as a Lightning callback, making it easy to plug into existing training configurations with minimal changes. This modular approach cleanly separates the core model architecture from the variance control mechanism and facilitates experimentation with different rescaling strategies across model sizes and architectures.
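Conceptually, the callback periodically rescales each selected 2D weight matrix so its empirical standard deviation matches `target_std`. A minimal sketch of that idea (not the repo's implementation, which also handles layer selection, depth scaling, and logging):

```python
import torch


@torch.no_grad()
def rescale_to_target_std(weight: torch.Tensor, target_std: float = 0.01) -> None:
    """Multiply a weight matrix in place so its empirical std hits target_std."""
    current_std = weight.std()
    if current_std > 0:
        weight.mul_(target_std / current_std)
```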
Example usage with a Llama-style model:

```python
import lightning as pl

from weight_rescaling import TVRCallback
from weight_rescaling.utils.utils import get_layers

# fv is your experiment configuration dictionary
layer0 = get_layers(model, fv["model_layer_path"])[0]

# Collect the names of all 2D weight modules in one transformer layer;
# these are the modules TVR will rescale.
valid_2d_module_names = [
    name
    for name, module in layer0.named_modules()
    if hasattr(module, "weight") and len(module.weight.shape) == 2
]

callbacks = []
callbacks.append(
    TVRCallback(
        valid_2d_module_names=valid_2d_module_names,
        additional_module_names_to_log=[
            "input_layernorm",
            "post_attention_layernorm",
            "lm_head",
            "model.norm",
            "model.embed_tokens",
        ],
        target_std=0.01,
        scale_with_depth=False,
        layer_path="model.layers",
        step_interval=fv["weight_rescaling_step_interval"],
    )
)

# Pass the callbacks to your Lightning trainer
trainer = pl.Trainer(
    # ... your other trainer arguments ...
    callbacks=callbacks,
)
```

If you find our work useful, please cite:

```bibtex
@misc{owen2025variancecontrolweightrescaling,
      title={Variance Control via Weight Rescaling in LLM Pre-training},
      author={Louis Owen and Abhay Kumar and Nilabhra Roy Chowdhury and Fabian Güra},
      year={2025},
      eprint={2503.17500},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2503.17500},
}
```


