This project explores the concept of generating a sequence of "golden noises" for text-to-image diffusion models, inspired by the ideas in "Golden Noise for Diffusion Models: A Learning Framework" (arXiv:2411.09502). Instead of predicting a single optimal noise, this implementation focuses on modeling the evolutionary process where noise is iteratively refined based on a text prompt, using a Recurrent Neural Network (RNN) based architecture (NoiseSequenceRNN_v3).
The core workflow involves:
- Generating a dataset where each sample contains an initial noise ($x_T$) and a sequence of subsequent noises ($[x'_1, \dots, x'_n]$) obtained through iterative DDIM Denoise/Inversion steps conditioned on a text prompt ($c$).
- Training an RNN model (`NoiseSequenceRNN_v3`) to learn the conditional transition $p_\theta(x'_k \mid x'_{k-1}, c)$, predicting the distribution of the next noise state.
- Using the trained RNN model during inference to generate a sequence of refined noises, and using the final noise ($\hat{x}'_n$) as an improved starting point for a standard diffusion model (e.g., SDXL).
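The rollout at inference time is autoregressive: each refined noise is sampled from the distribution predicted from the previous one. A toy numpy sketch of this idea, with `rnn_step` as a hypothetical stand-in for the real `NoiseSequenceRNN_v3` forward pass:

```python
import numpy as np

def rnn_step(x_prev, c, h):
    # Hypothetical stand-in for NoiseSequenceRNN_v3: returns the mean and
    # log-variance of p(x'_k | x'_{k-1}, c) plus an updated hidden state.
    h = 0.9 * h + 0.1 * (x_prev.mean() + c.mean())
    mean = x_prev + 0.05 * c.mean() + 0.01 * h
    logvar = np.full_like(x_prev, -4.0)  # small predicted variance
    return mean, logvar, h

def rollout(x_T, c, num_steps, rng):
    """Sample x'_1, ..., x'_n autoregressively and return the whole sequence."""
    x, h = x_T, 0.0
    sequence = []
    for _ in range(num_steps):
        mean, logvar, h = rnn_step(x, c, h)
        # Reparameterized sample from the predicted Gaussian.
        x = mean + np.exp(0.5 * logvar) * rng.standard_normal(x.shape)
        sequence.append(x)
    return sequence

rng = np.random.default_rng(0)
x_T = rng.standard_normal((4, 128, 128))  # initial Gaussian noise (SDXL latent shape)
c = rng.standard_normal(1280)             # pooled text embedding (e.g., SDXL CLIP-G, dim 1280)
sequence = rollout(x_T, c, num_steps=10, rng=rng)
golden_noise = sequence[-1]               # final refined noise, used to start the diffusion model
```

The dummy `rnn_step` only illustrates the interface; the trained model replaces it with a CNN encoder plus GRU over the noise sequence.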
```
CS5340_PROJECT/
├── data/                             # Datasets
│   ├── prompts.txt                   # Input prompts for dataset generation
│   ├── pickapic_prompts.txt          # Input prompts from the Pick-a-Pic dataset
│   ├── pickapic_test_prompts.txt     # Prompts for evaluation/testing
│   ├── test_prompts.txt              # Prompts for evaluation/testing
│   └── npd_sequence_dataset_sdxl/    # Generated sequence dataset (example)
│       ├── sequences/                # Saved noise sequence files (.pt)
│       └── metadata.csv              # Metadata linking prompts and sequence files
├── doc/                              # Documentation (optional)
├── inference_output/                 # Default output directory for inference images
│   ├── standard_output/              # Images from standard noise
│   └── gnsnet_output/                # Images from golden noise (RNN output)
├── model/                            # Model definitions
│   ├── __init__.py
│   └── rnn_seq_model_v3.py           # RNN Sequence Model (V3) definition
├── output/                           # Default output directory for training checkpoints
│   └── rnn_v3_seq_model_output/      # Example training output dir
├── references/                       # Reference papers (optional)
├── results/                          # Default output directory for evaluation results
├── scripts/                          # Utility and evaluation scripts
│   ├── batch_inference.py            # Generate images for multiple prompts
│   ├── evaluate_hps.py               # Evaluate generated images using HPSv2
│   ├── evalute.sh                    # Example evaluation script (shell)
│   ├── extract_prompts.py            # Extract prompts from a Hugging Face dataset
│   ├── generate_test_prompts.py      # Utility for creating test prompts
│   ├── inference.sh                  # Example inference script (shell)
│   └── train.sh                      # Example training script (shell)
├── src/                              # Core source code
│   ├── dataset.py                    # PyTorch Dataset and DataLoader for sequence data
│   ├── generate_npd_series.py        # Script to generate the noise sequence dataset
│   ├── inference_rnn.py              # Script to run inference with the RNN model
│   └── train_rnn_model_v3.py         # Script to train the RNN model (V3)
├── LICENSE                           # Project License
└── README.md                         # This file
```
Prerequisites:
- Python 3.8+
- PyTorch (CUDA recommended)
- CUDA Toolkit & compatible NVIDIA driver (if using GPU)

Clone Repository:

```bash
git clone https://github.com/fangda-ye/CS5340_Project.git
cd CS5340_PROJECT
```

Create Environment & Install Dependencies: Using Conda is recommended:

```bash
conda create -n golden_rnn python=3.8 -c conda-forge -y
conda activate golden_rnn
# Install PyTorch matching your CUDA version (check the PyTorch website)
# Example for CUDA 11.8:
# conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# Example for CUDA 12.1:
# conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
# Install other dependencies
pip install -r requirements.txt
```
Use the script to download prompts from a dataset like Pick-a-Pic.
```bash
python scripts/extract_prompts.py \
    --dataset_name yuvalkirstain/pickapic_v1 \
    --split train \
    --output_dir ./data \
    --output_filename prompts.txt
```

This will save prompts to `./data/prompts.txt`. You can also prepare your own `.txt` file with one prompt per line. Create a separate file (e.g., `./data/test_prompts.txt`) for evaluation later.
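If you prepare your own file, the one-prompt-per-line format can be read back with a few lines of Python. A minimal sketch (the loading logic inside the project scripts may differ; `load_prompts` is a hypothetical helper):

```python
import tempfile
from pathlib import Path

def load_prompts(path):
    """Read one prompt per line, skipping blank lines and surrounding whitespace."""
    text = Path(path).read_text(encoding="utf-8")
    return [line.strip() for line in text.splitlines() if line.strip()]

# Example: write a tiny prompt file and read it back.
path = Path(tempfile.gettempdir()) / "demo_prompts.txt"
path.write_text("a red fox in the snow\n\na futuristic cityscape\n", encoding="utf-8")
prompts = load_prompts(path)
```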
Use the prepared prompts and a base diffusion model (SDXL recommended) to generate the sequential noise data.
```bash
# Ensure you are in the CS5340_PROJECT directory
python src/generate_npd_series.py \
    --prompt_file ./data/prompts.txt \
    --output_dir ./data/npd_sequence_dataset_sdxl/ \
    --num_steps 10 \
    --max_prompts 1000  # Adjust as needed for dataset size
```

- `--num_steps`: Number of golden noise steps ($n$) to generate per prompt.
- `--max_prompts`: Limits the number of prompts processed (useful for creating smaller test datasets). Remove to process all prompts.
- This script saves `_source.pt` and `_golden_sequence.pt` files in the `sequences` sub-directory and creates a `metadata.csv`.
- Warning: This step is computationally intensive and time-consuming.
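The structure of the generation loop can be sketched with dummy stand-ins. This is only a shape-level illustration under assumed behavior: `ddim_denoise` and `ddim_invert` here are hypothetical placeholders for the real DDIM passes through the SDXL UNet in `generate_npd_series.py`:

```python
import numpy as np

def ddim_denoise(x_t, c):
    # Hypothetical stand-in for a full DDIM denoising pass conditioned on c.
    return x_t - 0.1 * c

def ddim_invert(x_0, c, rng):
    # Hypothetical stand-in for DDIM inversion from a clean latent back to noise.
    return x_0 + 0.1 * c + 0.01 * rng.standard_normal(x_0.shape)

def generate_sequence(x_T, c, num_steps, rng):
    """Alternate denoise/inversion to produce the refined noises x'_1, ..., x'_n."""
    sequence, x = [], x_T
    for _ in range(num_steps):
        x_0 = ddim_denoise(x, c)      # current noise -> approximate clean latent
        x = ddim_invert(x_0, c, rng)  # clean latent -> next refined noise x'_k
        sequence.append(x)
    return sequence

rng = np.random.default_rng(0)
x_T = rng.standard_normal((4, 128, 128))  # source noise (saved as *_source.pt)
c = rng.standard_normal((4, 128, 128))    # conditioning stand-in
sequence = generate_sequence(x_T, c, num_steps=10, rng=rng)  # saved as *_golden_sequence.pt
```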
Train the NoiseSequenceRNN_v3 model on the generated dataset using accelerate.
```bash
# Ensure PYTHONPATH includes the project root if running from root
# export PYTHONPATH=.  (or prefix the command with PYTHONPATH=.)
accelerate launch src/train_rnn_model_v3.py \
    --dataset_dir ./data/npd_sequence_dataset_sdxl/ \
    --output_dir ./output/rnn_v3_seq_model_output/ \
    --base_model_id stabilityai/stable-diffusion-xl-base-1.0 \
    --npnet_model_id SDXL `# For text dim hint` \
    --text_embed_dim 1280 `# Adjust if using a different base model/embedding` \
    --noise_resolution 128 \
    --cnn_base_filters 64 \
    --cnn_num_blocks 2 2 2 2 \
    --cnn_feat_dim 512 \
    --gru_hidden_size 1024 \
    --gru_num_layers 2 \
    --predict_variance `# Add if you want variance prediction` \
    --kl_weight 0.01 `# Add if using variance prediction and KL loss` \
    --num_epochs 50 \
    --batch_size 8 `# Adjust based on GPU memory` \
    --gradient_accumulation_steps 4 `# Adjust based on GPU memory` \
    --learning_rate 1e-4 \
    --mixed_precision fp16 \
    --save_steps 1000 \
    --max_checkpoints 3 `# Limit disk usage`
```

- Adjust hyperparameters (batch size, learning rate, model dimensions, etc.) based on your resources and dataset.
- `--text_embed_dim` should match the dimension of the text embedding used (e.g., 1280 for SDXL's pooled CLIP-G).
- Checkpoints and logs will be saved in `--output_dir`.
- Alternatively, you can directly use our pretrained model: https://drive.google.com/file/d/13pFW7f0lR37jenEtuUPu9pYwOBkvY0Li/view?usp=drive_link
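When `--predict_variance` is enabled, a natural training objective is a Gaussian negative log-likelihood on each transition plus a KL regularizer weighted by `--kl_weight`. The exact loss in `train_rnn_model_v3.py` may differ; the sketch below shows one plausible form in numpy:

```python
import numpy as np

def gaussian_nll(x, mean, logvar):
    # Negative log-likelihood of x under N(mean, exp(logvar)),
    # dropping the constant 0.5 * log(2 * pi) term.
    return 0.5 * (logvar + (x - mean) ** 2 / np.exp(logvar))

def kl_to_standard_normal(mean, logvar):
    # Elementwise KL( N(mean, exp(logvar)) || N(0, 1) ), keeping samples noise-like.
    return 0.5 * (mean ** 2 + np.exp(logvar) - logvar - 1.0)

def sequence_loss(targets, means, logvars, kl_weight=0.01):
    """Average per-element NLL over the sequence plus a weighted KL regularizer."""
    nll = np.mean([gaussian_nll(x, m, lv).mean()
                   for x, m, lv in zip(targets, means, logvars)])
    kl = np.mean([kl_to_standard_normal(m, lv).mean()
                  for m, lv in zip(means, logvars)])
    return nll + kl_weight * kl
```

With a perfect prediction (`mean == target`, `logvar == 0`) both terms vanish, so the loss is zero; this is a quick sanity check when wiring up a loss like this.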
Use the trained RNN model to generate an image for a specific prompt.
```bash
# Ensure PYTHONPATH includes the project root if running from root
# export PYTHONPATH=.
python src/inference_rnn.py \
    --rnn_weights_path ./output/rnn_v3_seq_model_output/rnn_v3_model_final.pth \
    --prompt "A futuristic cityscape at sunset, synthwave style" \
    --output_dir ./inference_output/ \
    --base_model_id stabilityai/stable-diffusion-xl-base-1.0 \
    --num_gen_steps 10 `# MUST match dataset num_steps` \
    --num_inference_steps 30 \
    --guidance_scale 5.5 \
    --seed 12345 \
    --generate_standard \
    --dtype float16
    # --- Add model config flags matching the trained model ---
    # e.g., --predict_variance (if trained with it)
    # (Other model dimension args are hardcoded in this script for now)
```

- Replace `--rnn_weights_path` with the actual path to your trained model.
- Set `--num_gen_steps` to the same value used during dataset generation.
- The script will save both the standard noise image and the golden noise image (if `--generate_standard` is used).
Generate images for all prompts listed in a file.
```bash
# Ensure PYTHONPATH includes the project root if running from root
# export PYTHONPATH=.
python scripts/batch_inference.py \
    --prompt_file ./data/test_prompts.txt \
    --output_base_dir ./inference_output/ \
    --rnn_weights_path ./output/rnn_v3_seq_model_output/rnn_v3_model_final.pth \
    --base_model_id stabilityai/stable-diffusion-xl-base-1.0 \
    --num_gen_steps 10 \
    --start_seed 1000  # Use a different starting seed than training/single inference
    # --- Add necessary model config flags ---
    # --predict_variance
```

- This script reads prompts from `--prompt_file`.
- It saves standard images to `<output_base_dir>/standard_output/` and golden noise images to `<output_base_dir>/gnsnet_output/`.
- Images are named `{index}.png`, corresponding to the line number in the prompt file.
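The `{index}.png` naming makes it easy to pair each prompt with its two output images afterwards. A small sketch under the assumption that the index is zero-based (`paired_outputs` is a hypothetical helper, not part of the repo):

```python
import tempfile
from pathlib import Path

def paired_outputs(prompt_file, output_base_dir):
    """Yield (prompt, standard_image, golden_image) triples, matching the {index}.png naming."""
    lines = Path(prompt_file).read_text(encoding="utf-8").splitlines()
    prompts = [line.strip() for line in lines if line.strip()]
    base = Path(output_base_dir)
    for i, prompt in enumerate(prompts):  # assumes zero-based file indices
        yield (prompt,
               base / "standard_output" / f"{i}.png",
               base / "gnsnet_output" / f"{i}.png")

# Example with a two-line prompt file:
prompt_file = Path(tempfile.gettempdir()) / "pairing_demo.txt"
prompt_file.write_text("a red fox\na snowy peak\n", encoding="utf-8")
pairs = list(paired_outputs(prompt_file, "./inference_output"))
```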
Evaluate the generated image pairs using the HPSv2 score.
```bash
# Ensure hpsv2 is installed: pip install hpsv2
python scripts/evaluate_hps.py \
    --prompt_file ./data/test_prompts.txt \
    --image_base_dir ./inference_output/ \
    --results_dir ./results/ \
    --hps_version v2.1
```

- `--prompt_file` should be the same file used for batch inference.
- `--image_base_dir` points to the directory containing `standard_output` and `gnsnet_output`.
- Results are saved to `<results_dir>/hpsv2_evaluation.csv`.
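A common way to summarize such paired scores is the win rate: the fraction of prompts where the golden-noise image out-scores the standard-noise one. The column names below (`standard_score`, `golden_score`) are assumptions for illustration; adapt them to whatever `hpsv2_evaluation.csv` actually contains:

```python
import csv, io

def golden_win_rate(csv_text, std_col="standard_score", golden_col="golden_score"):
    """Fraction of rows where the golden-noise image scores higher than the standard one."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    wins = sum(1 for r in rows if float(r[golden_col]) > float(r[std_col]))
    return wins / len(rows)

# Toy CSV with two prompts: golden noise wins once and loses once.
csv_text = (
    "prompt,standard_score,golden_score\n"
    "a red fox,0.271,0.289\n"
    "a snowy peak,0.305,0.294\n"
)
rate = golden_win_rate(csv_text)  # 0.5
```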