This repository contains the official implementation for our paper *Compressed and Smooth Latent Space for Text Diffusion Modeling*, which was accepted as a poster at NeurIPS 2025.
While autoregressive models dominate text generation, their sequential nature leads to slow decoding and challenges in maintaining global coherence. Diffusion models offer a parallelizable alternative, but their application to text is hindered by the high dimensionality of token-level representations.
We introduce COSMOS, a novel approach that operates entirely in a compressed, smooth latent space. This space is learned using an autoencoder trained for both token-level reconstruction and alignment with a pretrained language encoder, providing robust semantic grounding. Our method allows for an 8x compression of text representations while maintaining high quality, achieving comparable or superior results to strong baselines.
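The 8x figure is consistent with the shapes used throughout this README: sequences of up to 128 tokens are mapped to 16 latent vectors, so the sequence axis shrinks by 128 / 16 = 8. A trivial sanity check (the helper below is purely illustrative, not repository code):

```python
def compression_ratio(max_seq_len: int, num_latents: int) -> float:
    """Sequence-axis compression: input tokens per example divided by
    latent vectors per example (illustrative helper, not repository code)."""
    return max_seq_len / num_latents

# Settings used in this README's example commands (wikipedia autoencoder).
print(compression_ratio(128, 16))  # -> 8.0
```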
To get started, set up a virtual environment and install the required dependencies using uv:
```shell
# Install uv (if not already installed)
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create and activate virtual environment
uv venv .venv
source .venv/bin/activate

# Set the project root directory
export PROJECT_ROOT=$(pwd)

# Install dependencies
uv sync
```

You will also need to authorize with Weights & Biases for experiment tracking:
```shell
wandb login
```

For convenience, all datasets used in the paper have been pre-processed and are available on the Hugging Face Hub. We recommend using these pre-processed datasets. You can find all of them here: bayes-group-diffusion/datasets.
The training scripts will automatically download and save them in the data directory. You just need to ensure the dataset in your configuration file (conf/config.yaml) points to the correct dataset. For example, for rocstories:
```yaml
# in conf/config.yaml
- dataset: "rocstories"
```

You can also fetch a dataset from the Hub manually:

```shell
python -m utils.load_to_hub --config_path ../conf/ --load_from_hub
# or: uv run python -m utils.load_to_hub --config_path ../conf/ --load_from_hub
```

To use other datasets, update the configuration file accordingly:
```yaml
# in conf/config.yaml
- dataset: "wikipedia"  # or "openwebtext-128", "openwebtext-512"
```

An example of the data preprocessing script is available at utils/owt_preparation.py. The Wikipedia and OpenWebText datasets were prepared using a similar process, mainly differing in the text chunk length.
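The chunking step itself is simple. The helper below is an illustrative sketch (not the actual utils/owt_preparation.py) of splitting a tokenized corpus into fixed-length chunks:

```python
def chunk_tokens(token_ids: list[int], chunk_len: int) -> list[list[int]]:
    """Split a flat stream of token ids into fixed-length chunks,
    dropping a trailing remainder shorter than chunk_len.
    Illustrative sketch only; not the repository's preprocessing script."""
    n_full = len(token_ids) // chunk_len
    return [token_ids[i * chunk_len:(i + 1) * chunk_len] for i in range(n_full)]

# e.g. chunk_len=128 for wikipedia/openwebtext-128, 512 for openwebtext-512
chunks = chunk_tokens(list(range(300)), 128)
print(len(chunks), len(chunks[0]))  # -> 2 128
```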
We provide pretrained model checkpoints on AWS S3. To download them, you will need to have the AWS CLI installed and configured.
First, install the necessary packages to interact with AWS:
```shell
pip install boto3 awscli
```

Next, configure your AWS credentials. If you haven't done this before, run the following command and follow the prompts:

```shell
aws configure
```

### 1. Autoencoder
To download the autoencoder, create the destination directory and run the copy command:
```shell
mkdir -p ./checkpoints/autoencoder-num_latents=16-wikipedia-final-128/
aws s3 cp s3://cosmos-latent-diffusion/checkpoints/autoencoder-num_latents=16-wikipedia-final-128/100000.pth ./checkpoints/autoencoder-num_latents=16-wikipedia-final-128/100000.pth --region eu-north-1
```

Available checkpoints under `s3://cosmos-latent-diffusion/checkpoints`:

- `autoencoder-num_latents=16-wikipedia-final-128/100000.pth`
- `autoencoder-num_latents=32-wikipedia-final-128/100000.pth`
- `autoencoder-num_latents=64-wikipedia-final-128/100000.pth`
- `autoencoder-num_latents=128-wikipedia-final-128/100000.pth`
- `autoencoder-num_latents=512-openwebtext-512-final-512/200000.pth`
The parts of the checkpoint name mean:

- `num_latents=16`: number of latents
- `wikipedia`: dataset name
- `final`: final checkpoint
- `128`: maximum sequence length
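When scripting over several checkpoints, it can help to parse these names programmatically. The helper below is a convenience sketch (not part of the repository):

```python
import re

# Matches names like "autoencoder-num_latents=16-wikipedia-final-128"
AE_NAME = re.compile(
    r"autoencoder-num_latents=(?P<num_latents>\d+)-(?P<dataset>.+)-final-(?P<max_len>\d+)"
)

def parse_autoencoder_name(name: str) -> dict:
    """Split an autoencoder checkpoint name into its components.
    Convenience sketch only; not part of the repository."""
    m = AE_NAME.fullmatch(name)
    if m is None:
        raise ValueError(f"unrecognized checkpoint name: {name}")
    return {
        "num_latents": int(m.group("num_latents")),
        "dataset": m.group("dataset"),
        "max_len": int(m.group("max_len")),
    }

print(parse_autoencoder_name("autoencoder-num_latents=16-wikipedia-final-128"))
# -> {'num_latents': 16, 'dataset': 'wikipedia', 'max_len': 128}
```

Note that the greedy `dataset` group still handles dataset names containing digits and dashes, such as `openwebtext-512`.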
### 2. Diffusion Model
To download the diffusion model, create the destination directory and run the copy command:
```shell
mkdir -p ./checkpoints/diffusion-rocstories-16-d=5-final/
aws s3 cp s3://cosmos-latent-diffusion/checkpoints/diffusion-rocstories-16-d=5-final/180000.pth ./checkpoints/diffusion-rocstories-16-d=5-final/180000.pth --region eu-north-1
```

Available checkpoints under `s3://cosmos-latent-diffusion/checkpoints`:

- `diffusion-rocstories-16-d=5-final/180000.pth`
- `diffusion-rocstories-32-d=5-final/200000.pth`
- `diffusion-rocstories-64-d=7-final/200000.pth`
- `diffusion-openwebtext-512-512-d=5-final-512/500000.pth`
The parts of the checkpoint name mean:

- `rocstories`: dataset name
- `16`: number of latents
- `d=5`: scheduler parameter
- `final`: final checkpoint
The training process consists of two main stages: training the autoencoder and training the diffusion model.
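As a shape-level sketch of how the two stages connect (the functions below are stubs, not the repository's models; mean-pooling merely stands in for the learned encoder, and the shapes follow the `num_latents=16`, max length 128 setup used in the commands that follow):

```python
import numpy as np

# Config values used in this README's example commands.
NUM_LATENTS, SEQ_LEN, DIM = 16, 128, 768

def encode(tokens: np.ndarray) -> np.ndarray:
    """Stage-1 stand-in: compress a (SEQ_LEN, DIM) token representation
    into (NUM_LATENTS, DIM) latents. The real autoencoder is a trained
    model; mean-pooling here only illustrates the shape change."""
    return tokens.reshape(NUM_LATENTS, SEQ_LEN // NUM_LATENTS, DIM).mean(axis=1)

def denoise(latents: np.ndarray, n_steps: int = 200) -> np.ndarray:
    """Stage-2 stand-in: the diffusion model iteratively refines noise
    into clean latents; returned unchanged here for illustration."""
    return latents

x = np.random.randn(SEQ_LEN, DIM)      # token-level representation
z = encode(x)                          # stage 1: compress 128 -> 16 vectors
z_hat = denoise(np.random.randn(NUM_LATENTS, DIM))  # stage 2: sample latents
print(z.shape, z_hat.shape)  # -> (16, 768) (16, 768)
```

Stage 1 trains the encoder/decoder; stage 2 trains the denoiser entirely in the resulting latent space.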
Train the autoencoder to learn a compressed latent representation of the text:
```shell
HYDRA_FULL_ERROR=1 \
uv run \
torchrun --nproc_per_node=4 --master_port=12346 train_encoder.py \
dataset=wikipedia \
encoder.latent.num_latents=16 \
decoder.latent.num_latents=16 \
encoder.augmentation.masking.weight=0.5 \
encoder.augmentation.masking.encodings_mlm_probability=0.3 \
encoder.augmentation.gaussian_noise.weight=0.5 \
encoder.augmentation.gaussian_noise.delta=0.7 \
encoder.augmentation.latent_masking.probability=0.4 \
autoencoder.latent.dim=768 \
autoencoder.latent.num_latents=16 \
training.training_iters=100000 \
training="autoencoder" \
suffix="final"
```

Once the autoencoder is trained, use its weights to train the diffusion model on the latent space:
```shell
CUDA_LAUNCH_BLOCKING=1 \
HYDRA_FULL_ERROR=1 \
uv run \
torchrun --nproc_per_node=4 --master_port=12345 \
train_diffusion.py \
dataset=rocstories \
diffusion.dynamic.N=200 \
diffusion.dynamic.d=5 \
diffusion.training.batch_size=512 \
encoder.latent.num_latents=16 \
encoder.embedding.max_position_embeddings=128 \
decoder.latent.num_latents=16 \
decoder.embedding.max_position_embeddings=128 \
autoencoder.model.load_checkpoint='"autoencoder-num_latents=16-wikipedia-final-128/100000.pth"' \
diffusion.generation.num_gen_texts=2000 \
training=diffusion \
suffix="final"
```

After training the diffusion model, you can generate new text samples:
```shell
CUDA_LAUNCH_BLOCKING=1 \
HYDRA_FULL_ERROR=1 \
uv run \
torchrun --nproc_per_node=4 --master_port=12345 \
generate.py \
dataset=rocstories \
diffusion.dynamic.N=200 \
diffusion.dynamic.d=5 \
diffusion.training.batch_size=512 \
encoder.latent.num_latents=16 \
encoder.embedding.max_position_embeddings=128 \
decoder.latent.num_latents=16 \
decoder.embedding.max_position_embeddings=128 \
autoencoder.model.load_checkpoint='"autoencoder-num_latents=16-wikipedia-final-128/100000.pth"' \
diffusion.model.load_checkpoint='"diffusion-rocstories-16-d=5-final/180000.pth"' \
diffusion.generation.num_gen_texts=2000 \
training=""
```

The repository is organized as follows:

```
cosmos/
├── 📁 conf/                 # Hydra configuration files
├── 📁 estimation/           # Metrics and quality assessment code
├── 📁 utils/                # Data preparation and logging utilities
├── 📁 architecture/         # Model architectures
├── 📁 diffusion_utils/      # Diffusion dynamic, scheduler, and solver
├── 🐍 diffusion_trainer.py  # Diffusion trainer main class
├── 🐍 encoder_trainer.py    # Encoder trainer main class
├── 🐍 train_encoder.py      # Script for training the autoencoder
├── 🐍 train_diffusion.py    # Script for training the diffusion model
└── 🐍 generate.py           # Script for text generation
```
If you use this work, please cite our paper.
If you are interested in collaborating, please reach out to us at meshchaninov.viacheslav@gmail.com or vmeshchani@constructor.university.