diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e0b7af4898b2..743425384776 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -500,6 +500,8 @@ title: AuraFlow - local: api/pipelines/blip_diffusion title: BLIP-Diffusion + - local: api/pipelines/block_refinement + title: Block Refinement - local: api/pipelines/bria_3_2 title: Bria 3.2 - local: api/pipelines/bria_fibo @@ -578,6 +580,8 @@ title: Latent Diffusion - local: api/pipelines/ledits_pp title: LEDITS++ + - local: api/pipelines/llada2 + title: LLaDA2 - local: api/pipelines/longcat_image title: LongCat-Image - local: api/pipelines/lumina2 @@ -714,6 +718,8 @@ - sections: - local: api/schedulers/overview title: Overview + - local: api/schedulers/block_refinement + title: BlockRefinementScheduler - local: api/schedulers/cm_stochastic_iterative title: CMStochasticIterativeScheduler - local: api/schedulers/ddim_cogvideox diff --git a/docs/source/en/api/pipelines/block_refinement.md b/docs/source/en/api/pipelines/block_refinement.md new file mode 100644 index 000000000000..d3f313f6ff29 --- /dev/null +++ b/docs/source/en/api/pipelines/block_refinement.md @@ -0,0 +1,60 @@ + + +# Block Refinement + +`BlockRefinementPipeline` performs block-wise iterative refinement over a masked token template, sampling and +committing tokens based on confidence. + +## Config defaults + +You can set default sampling parameters when creating the pipeline. Passing `None` for a parameter in `__call__` +falls back to `pipe.config`. + +```py +from diffusers import BlockRefinementPipeline, BlockRefinementScheduler + +scheduler = BlockRefinementScheduler() +pipe = BlockRefinementPipeline( + model=model, + scheduler=scheduler, + tokenizer=tokenizer, +) + +out = pipe(prompt="Explain gradient descent.", gen_length=256, block_length=32, steps=16, temperature=0.8) +print(out.texts[0]) +``` + +## Callbacks + +Callbacks run after each refinement step and can inspect or override the current tokens. 
+ +```py +def on_step_end(pipe, step, timestep, callback_kwargs): + cur_x = callback_kwargs["cur_x"] + # Inspect or modify `cur_x` here. + return {"cur_x": cur_x} + +out = pipe( + prompt="Write a short poem.", + callback_on_step_end=on_step_end, + callback_on_step_end_tensor_inputs=["cur_x"], +) +``` + +## BlockRefinementPipeline +[[autodoc]] BlockRefinementPipeline + - all + - __call__ + +## BlockRefinementPipelineOutput +[[autodoc]] pipelines.BlockRefinementPipelineOutput diff --git a/docs/source/en/api/pipelines/llada2.md b/docs/source/en/api/pipelines/llada2.md new file mode 100644 index 000000000000..ba9330e4f5b3 --- /dev/null +++ b/docs/source/en/api/pipelines/llada2.md @@ -0,0 +1,23 @@ + + +# LLaDA2 + +`LLaDA2Pipeline` adapts block refinement sampling for LLaDA2-style token diffusion models. + +## LLaDA2Pipeline +[[autodoc]] LLaDA2Pipeline + - all + - __call__ + +## LLaDA2PipelineOutput +[[autodoc]] pipelines.LLaDA2PipelineOutput diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 22fcf560eaca..e7076ab6732b 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -34,6 +34,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [AudioLDM2](audioldm2) | text2audio | | [AuraFlow](aura_flow) | text2image | | [BLIP Diffusion](blip_diffusion) | text2image | +| [Block Refinement](block_refinement) | text2text | | [Bria 3.2](bria_3_2) | text2image | | [CogVideoX](cogvideox) | text2video | | [Consistency Models](consistency_models) | unconditional image generation | @@ -62,6 +63,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [Latent Diffusion](latent_diffusion) | text2image, super-resolution | | [Latte](latte) | text2image | | [LEDITS++](ledits_pp) | image editing | +| [LLaDA2](llada2) | text2text | | [Lumina-T2X](lumina) | text2image | | [Marigold](marigold) | depth-estimation, normals-estimation, 
intrinsic-decomposition | | [MultiDiffusion](panorama) | text2image | diff --git a/docs/source/en/api/schedulers/block_refinement.md b/docs/source/en/api/schedulers/block_refinement.md new file mode 100644 index 000000000000..da6209f61242 --- /dev/null +++ b/docs/source/en/api/schedulers/block_refinement.md @@ -0,0 +1,25 @@ + + +# BlockRefinementScheduler + +The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it +commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different +token with high confidence. + +This scheduler is used by [`BlockRefinementPipeline`] and [`LLaDA2Pipeline`]. + +## BlockRefinementScheduler +[[autodoc]] BlockRefinementScheduler + +## BlockRefinementSchedulerOutput +[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput diff --git a/examples/discrete_diffusion/README.md b/examples/discrete_diffusion/README.md new file mode 100644 index 000000000000..da41c6040c49 --- /dev/null +++ b/examples/discrete_diffusion/README.md @@ -0,0 +1,72 @@ +# Discrete Token Diffusion (Experimental) + +This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions. + +## Block refinement (commit-by-confidence) + +Block refinement iteratively generates text in fixed-size blocks. At each step the model predicts all tokens in the block, commits the most confident ones, and re-masks the rest for further refinement. 
+ +### Train (Qwen causal LM) + +```bash +accelerate launch examples/discrete_diffusion/train_block_refinement_qwen_cap.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --text_column text \ + --output_dir qwen-block-refinement-output \ + --max_train_steps 1000 \ + --prompt_length 32 \ + --block_length 32 \ + --lambda_conf 2.0 \ + --conf_temperature 0.5 +``` + +If you don't want to download a dataset, you can use random-token data: + +```bash +accelerate launch examples/discrete_diffusion/train_block_refinement_qwen_cap.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --output_dir qwen-block-refinement-output \ + --use_dummy_data \ + --num_dummy_samples 2048 +``` + +### Sample + +```bash +python examples/discrete_diffusion/sample_block_refinement.py \ + --checkpoint_path qwen-block-refinement-output/final \ + --device cuda \ + --attention_mask_mode 2d \ + --prompt "Write a short paragraph about diffusion models." \ + --gen_length 128 +``` + +For causal LMs that only support a 2D `attention_mask`, use `--attention_mask_mode 2d`. + +## LLaDA2 sampling + +[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) uses block refinement with a masked language model backbone. The `LLaDA2Pipeline` wraps `BlockRefinementPipeline` with LLaDA2-specific defaults. + +```bash +python examples/discrete_diffusion/sample_llada2.py \ + --model_id inclusionAI/LLaDA-8B-Instruct \ + --prompt "Write a short poem about the ocean." \ + --gen_length 128 \ + --steps 128 +``` + +### LLaDA2.1 editing support + +LLaDA2.1 models support post-mask token editing via `--editing_threshold`: + +```bash +python examples/discrete_diffusion/sample_llada2.py \ + --model_id inclusionAI/LLaDA2.1-8B-Instruct \ + --prompt "Explain quantum computing in simple terms." 
\ + --gen_length 256 \ + --steps 256 \ + --editing_threshold 0.4 \ + --max_post_steps 2 +``` diff --git a/examples/discrete_diffusion/sample_block_refinement.py b/examples/discrete_diffusion/sample_block_refinement.py new file mode 100644 index 000000000000..e96fb7f88f0d --- /dev/null +++ b/examples/discrete_diffusion/sample_block_refinement.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import BlockRefinementPipeline, BlockRefinementScheduler + + +def main(): + parser = argparse.ArgumentParser(description="Sample with BlockRefinementPipeline using a transformers causal LM.") + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--prompt", type=str, default="Write a short paragraph about diffusion models.") + parser.add_argument("--gen_length", type=int, default=128) + parser.add_argument("--block_length", type=int, default=32) + parser.add_argument("--steps", type=int, default=32) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--top_k", type=int, default=0) + parser.add_argument("--threshold", type=float, default=0.95) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--attention_mask_mode", type=str, default="2d", choices=["auto", "4d", "2d", "none"]) + + args = parser.parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, use_fast=True, cache_dir=args.cache_dir) + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + torch_dtype=torch.bfloat16 if args.device.startswith("cuda") else torch.float32, + cache_dir=args.cache_dir, + ) + model.to(args.device) + model.eval() + + if 
tokenizer.mask_token_id is None: + raise ValueError("Tokenizer must have `mask_token_id` for block refinement sampling.") + + scheduler = BlockRefinementScheduler() + pipe = BlockRefinementPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to(args.device) + gen = torch.Generator(device=args.device).manual_seed(args.seed) + + prompt_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(args.device) + out = pipe( + prompt_ids=prompt_ids, + gen_length=int(args.gen_length), + block_length=int(args.block_length), + steps=int(args.steps), + temperature=float(args.temperature), + top_p=None if args.top_p >= 1.0 else float(args.top_p), + top_k=None if args.top_k <= 0 else int(args.top_k), + threshold=float(args.threshold), + eos_early_stop=True, + eos_token_id=int(tokenizer.eos_token_id) if tokenizer.eos_token_id is not None else None, + mask_token_id=int(tokenizer.mask_token_id), + attention_mask_mode=args.attention_mask_mode, + generator=gen, + return_text=True, + ) + + print(out.texts[0] if out.texts is not None else tokenizer.decode(out.sequences[0], skip_special_tokens=True)) + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_llada2.py b/examples/discrete_diffusion/sample_llada2.py new file mode 100644 index 000000000000..59b3806daf53 --- /dev/null +++ b/examples/discrete_diffusion/sample_llada2.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for LLaDA2-style discrete diffusion text generation. + +This script demonstrates how to use the LLaDA2Pipeline for text generation +using block-wise iterative refinement. + +Example usage: + python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?" + python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7 +""" + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import BlockRefinementScheduler, LLaDA2Pipeline +from diffusers.hooks import apply_group_offloading + + +def main(): + parser = argparse.ArgumentParser( + description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion." + ) + parser.add_argument( + "--model_id", + type=str, + default="inclusionAI/LLaDA2.0-mini", + help="HuggingFace model ID or path to local model.", + ) + parser.add_argument( + "--prompt", + type=str, + default="Why does Camus think that Sisyphus is happy?", + help="Text prompt to generate from.", + ) + parser.add_argument( + "--gen_length", + type=int, + default=2048, + help="Number of tokens to generate.", + ) + parser.add_argument( + "--block_length", + type=int, + default=32, + help="Size of each generation block.", + ) + parser.add_argument( + "--steps", + type=int, + default=32, + help="Number of refinement steps per block.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature (0.0 for greedy).", + ) + parser.add_argument( + "--top_p", + type=float, + default=None, + help="Nucleus sampling probability threshold.", + ) + parser.add_argument( + "--top_k", + type=int, + default=None, + help="Top-k sampling parameter.", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.95, + help="Confidence threshold 
for committing tokens.", + ) + parser.add_argument( + "--editing_threshold", + type=float, + default=None, + help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).", + ) + parser.add_argument( + "--max_post_steps", + type=int, + default=0, + help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.", + ) + parser.add_argument( + "--sampling_method", + type=str, + default="multinomial", + choices=["auto", "greedy", "multinomial"], + help="Sampling method for block refinement.", + ) + parser.add_argument( + "--eos_early_stop", + action="store_true", + help="Stop generation early when EOS token is generated.", + ) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float32", "float16", "bfloat16"], + help="Model dtype.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--offload", + type=str, + default=None, + choices=["group", "sequential"], + help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).", + ) + + args = parser.parse_args() + + # Parse dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = dtype_map[args.dtype] + + print(f"Loading model: {args.model_id}") + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + + # Load model with 
appropriate memory settings based on offload strategy + if args.offload == "group": + # For group offloading, load to CPU first then apply hooks + print("Using group offloading for memory efficiency...") + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + # Apply group offloading with CUDA streams for better performance + onload_device = torch.device(args.device) + offload_device = torch.device("cpu") + apply_group_offloading( + model, + onload_device=onload_device, + offload_device=offload_device, + offload_type="leaf_level", + use_stream=True, + ) + elif args.offload == "sequential": + # For sequential offloading, load to CPU first + print("Using sequential CPU offloading (slower but lower memory)...") + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + # Sequential offloading will be applied via pipeline + else: + # Default: use device_map="auto" for automatic memory management + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + device_map="auto", + low_cpu_mem_usage=True, + ) + model.eval() + + # Create pipeline + scheduler = BlockRefinementScheduler() + pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + + # Apply sequential CPU offload if requested + if args.offload == "sequential": + pipe.enable_sequential_cpu_offload() + + # Set up generator for reproducibility + generator = None + if args.seed is not None: + generator = torch.Generator(device=args.device).manual_seed(args.seed) + + print(f"\nPrompt: {args.prompt}") + print(f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.steps}") + print("-" * 50) + + # Generate + output = pipe( + prompt=args.prompt, + use_chat_template=args.use_chat_template, + 
add_generation_prompt=args.add_generation_prompt, + gen_length=args.gen_length, + block_length=args.block_length, + steps=args.steps, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + threshold=args.threshold, + editing_threshold=args.editing_threshold, + max_post_steps=args.max_post_steps, + sampling_method=args.sampling_method, + eos_early_stop=args.eos_early_stop, + generator=generator, + ) + + print("\nGenerated text:") + print(output.texts[0]) + + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_block_refinement_cap.py b/examples/discrete_diffusion/train_block_refinement_cap.py new file mode 100644 index 000000000000..061c9900303f --- /dev/null +++ b/examples/discrete_diffusion/train_block_refinement_cap.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json +import math +import os +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from torch.utils.data import DataLoader, Dataset + +from diffusers import BlockRefinementPipeline, BlockRefinementScheduler +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + output_dir: str + seed: int + max_train_steps: int + logging_steps: int + checkpointing_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + + vocab_size: int + mask_token_id: int + eos_token_id: int + max_length: int + prompt_length: int + + block_length: int + steps: int + lambda_conf: float + conf_temperature: float + temperature: float + threshold: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser( + description="Train a block-wise refinement model with a confidence-aware objective (CAP-style)." 
+ ) + + parser.add_argument("--output_dir", type=str, default="block-refinement-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--logging_steps", type=int, default=50) + parser.add_argument("--checkpointing_steps", type=int, default=500) + + parser.add_argument("--per_device_train_batch_size", type=int, default=64) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=2e-4) + parser.add_argument("--weight_decay", type=float, default=0.0) + + parser.add_argument("--vocab_size", type=int, default=256) + parser.add_argument("--mask_token_id", type=int, default=255) + parser.add_argument("--eos_token_id", type=int, default=254) + parser.add_argument("--max_length", type=int, default=64) + parser.add_argument("--prompt_length", type=int, default=8) + + parser.add_argument("--block_length", type=int, default=16) + parser.add_argument("--steps", type=int, default=16) + parser.add_argument("--lambda_conf", type=float, default=2.0) + parser.add_argument("--conf_temperature", type=float, default=0.5) + + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--threshold", type=float, default=0.95) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def build_block_attention_mask( + *, + num_blocks: int, + block_length: int, + total_length: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device, dtype=torch.bool)) + attn = ( + block_mask.repeat_interleave(block_length, dim=0) + .repeat_interleave(block_length, dim=1) + .unsqueeze(0) + .unsqueeze(0) + ) + attn = attn[:, :, :total_length, :total_length] + return torch.where( + attn, torch.zeros((), device=device, dtype=dtype), torch.full((), float("-inf"), device=device, dtype=dtype) + ) + + +def 
forward_process_semi_ar( + input_ids: torch.LongTensor, + *, + prompt_length: int, + block_length: int, + mask_token_id: int, + generator: Optional[torch.Generator], +) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]: + batch_size, seq_len = input_ids.shape + device = input_ids.device + + noisy = input_ids.clone() + noisy_rev = input_ids.clone() + masked = torch.zeros_like(input_ids, dtype=torch.bool) + masked_rev = torch.zeros_like(input_ids, dtype=torch.bool) + + start = int(prompt_length) + for block_start in range(start, seq_len, int(block_length)): + block_end = min(seq_len, block_start + int(block_length)) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg_rev = ~seg + + masked[:, block_start:block_end] = seg + masked_rev[:, block_start:block_end] = seg_rev + + noisy = torch.where(masked, torch.full_like(noisy, int(mask_token_id)), noisy) + noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, int(mask_token_id)), noisy_rev) + return noisy, noisy_rev, masked, masked_rev + + +class RandomTokenDataset(Dataset): + def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, eos_token_id: int): + self.num_samples = int(num_samples) + self.seq_len = int(seq_len) + self.vocab_size = int(vocab_size) + self.eos_token_id = int(eos_token_id) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + del idx + # Keep EOS out of the training distribution to avoid trivial early-stops during sampling. 
+ ids = torch.randint(0, self.vocab_size - 2, (self.seq_len,), dtype=torch.long) + return {"input_ids": ids} + + +class TinyBlockRefinementLM(torch.nn.Module): + def __init__(self, *, vocab_size: int, hidden_size: int = 128, num_heads: int = 4, num_layers: int = 4): + super().__init__() + self.vocab_size = int(vocab_size) + self.hidden_size = int(hidden_size) + + self.token_emb = torch.nn.Embedding(self.vocab_size, self.hidden_size) + self.pos_emb = torch.nn.Embedding(2048, self.hidden_size) + enc_layer = torch.nn.TransformerEncoderLayer( + d_model=self.hidden_size, + nhead=int(num_heads), + dim_feedforward=self.hidden_size * 4, + dropout=0.0, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.encoder = torch.nn.TransformerEncoder(enc_layer, num_layers=int(num_layers)) + self.lm_head = torch.nn.Linear(self.hidden_size, self.vocab_size, bias=False) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + if position_ids is None: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids) + + x = self.token_emb(input_ids) + self.pos_emb(position_ids) + + attn_mask = None + if attention_mask is not None: + if attention_mask.ndim == 4: + attn_mask = attention_mask[0, 0] + elif attention_mask.ndim == 2: + attn_mask = attention_mask + else: + raise ValueError(f"Unsupported `attention_mask` shape: {attention_mask.shape}") + attn_mask = attn_mask.to(dtype=torch.float32) + + hidden = self.encoder(x, mask=attn_mask) + logits = self.lm_head(hidden) + return type("Output", (), {"logits": logits}) + + +def save_checkpoint(output_dir: str, *, model: torch.nn.Module, cfg: TrainConfig): + os.makedirs(output_dir, exist_ok=True) + torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin")) + with open(os.path.join(output_dir, "training_config.json"), "w", encoding="utf-8") as f: + 
json.dump(asdict(cfg), f, indent=2, sort_keys=True) + + +def main(): + cfg = parse_args() + if cfg.mask_token_id >= cfg.vocab_size: + raise ValueError("`mask_token_id` must be < `vocab_size`.") + if cfg.eos_token_id >= cfg.vocab_size: + raise ValueError("`eos_token_id` must be < `vocab_size`.") + if cfg.prompt_length >= cfg.max_length: + raise ValueError("`prompt_length` must be < `max_length`.") + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + dataset = RandomTokenDataset( + num_samples=max(cfg.max_train_steps * cfg.per_device_train_batch_size, 4096), + seq_len=cfg.max_length, + vocab_size=cfg.vocab_size, + eos_token_id=cfg.eos_token_id, + ) + dataloader = DataLoader(dataset, batch_size=cfg.per_device_train_batch_size, shuffle=True, drop_last=True) + + model = TinyBlockRefinementLM(vocab_size=cfg.vocab_size) + scheduler = BlockRefinementScheduler() + pipe = BlockRefinementPipeline(model=model, scheduler=scheduler, tokenizer=None) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + pipe = pipe.to(accelerator.device) + + global_step = 0 + model.train() + + for _epoch in range(num_train_epochs): + for batch in dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + + # Build the same attention mask that the sampler uses. 
+ prompt_len = int(cfg.prompt_length) + num_blocks = (prompt_len + int(cfg.max_length - prompt_len) + int(cfg.block_length) - 1) // int( + cfg.block_length + ) + total_length = int(num_blocks) * int(cfg.block_length) + total_length = max(total_length, int(cfg.max_length)) + attn_mask = build_block_attention_mask( + num_blocks=(total_length + int(cfg.block_length) - 1) // int(cfg.block_length), + block_length=int(cfg.block_length), + total_length=int(cfg.max_length), + device=input_ids.device, + dtype=torch.bfloat16 if input_ids.device.type == "cuda" else torch.float32, + ) + position_ids = ( + torch.arange(int(cfg.max_length), device=input_ids.device, dtype=torch.long) + .unsqueeze(0) + .expand_as(input_ids) + ) + + gen = None + if accelerator.is_local_main_process: + gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step) + + noisy, noisy_rev, masked, masked_rev = forward_process_semi_ar( + input_ids, + prompt_length=prompt_len, + block_length=int(cfg.block_length), + mask_token_id=int(cfg.mask_token_id), + generator=gen, + ) + + logits = model(noisy, attention_mask=attn_mask, position_ids=position_ids).logits + logits_rev = model(noisy_rev, attention_mask=attn_mask, position_ids=position_ids).logits + + # Do not allow predicting mask_id. 
+ logits = logits.clone() + logits[..., int(cfg.mask_token_id)] = torch.finfo(logits.dtype).min + logits_rev = logits_rev.clone() + logits_rev[..., int(cfg.mask_token_id)] = torch.finfo(logits_rev.dtype).min + + labels = input_ids.clone() + labels[~masked] = -100 + labels_rev = input_ids.clone() + labels_rev[~masked_rev] = -100 + + weights = masked.to(dtype=logits.dtype) + weights_rev = masked_rev.to(dtype=logits.dtype) + + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights, + ) + loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss( + logits_rev, + labels_rev, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights_rev, + ) + + total_loss = loss + loss_rev + accelerator.backward(total_loss) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f sft=%.4f conf=%.4f", + global_step, + total_loss.item(), + (loss_sft + loss_sft_rev).item(), + (loss_conf + loss_conf_rev).item(), + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + save_checkpoint(save_dir, model=accelerator.unwrap_model(model), cfg=cfg) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + save_checkpoint(final_dir, model=accelerator.unwrap_model(model), cfg=cfg) + + # Quick sampler smoke to ensure the pipeline runs with the trained weights. 
+ out = pipe( + prompt_ids=torch.randint(0, cfg.vocab_size - 2, (1, cfg.prompt_length), device=accelerator.device), + gen_length=int(cfg.max_length - cfg.prompt_length), + block_length=int(cfg.block_length), + steps=int(cfg.steps), + temperature=float(cfg.temperature), + threshold=float(cfg.threshold), + eos_early_stop=False, + eos_token_id=int(cfg.eos_token_id), + mask_token_id=int(cfg.mask_token_id), + return_text=False, + ) + logger.info("sample shape=%s", tuple(out.sequences.shape)) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_block_refinement_qwen_cap.py b/examples/discrete_diffusion/train_block_refinement_qwen_cap.py new file mode 100644 index 000000000000..5149f3ba61d0 --- /dev/null +++ b/examples/discrete_diffusion/train_block_refinement_qwen_cap.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import argparse
import math
import os
from dataclasses import asdict, dataclass
from typing import Dict, Optional, Tuple

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler

from diffusers.training_utils import compute_confidence_aware_loss


logger = get_logger(__name__)


@dataclass
class TrainConfig:
    """CLI-backed configuration for block-refinement training on a causal LM."""

    # Data
    model_name_or_path: str
    dataset_name: str
    dataset_config_name: Optional[str]
    text_column: str
    cache_dir: Optional[str]
    use_dummy_data: bool
    num_dummy_samples: int

    # Run management
    output_dir: str
    seed: int
    max_train_steps: int
    checkpointing_steps: int
    logging_steps: int

    # Optimization
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    weight_decay: float
    lr_scheduler: str
    lr_warmup_steps: int

    # Sequence layout
    max_length: int
    prompt_length: int
    block_length: int

    # Confidence-aware loss
    lambda_conf: float
    conf_temperature: float


def parse_args() -> TrainConfig:
    """Parse CLI arguments into a `TrainConfig`.

    Returns:
        `TrainConfig`: the fully-populated training configuration.
    """
    parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")

    parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
    parser.add_argument("--dataset_name", type=str, default="wikitext")
    parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
    parser.add_argument("--text_column", type=str, default="text")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
    parser.add_argument("--num_dummy_samples", type=int, default=2048)

    parser.add_argument("--output_dir", type=str, default="qwen-block-refinement-output")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_train_steps", type=int, default=1000)
    parser.add_argument("--checkpointing_steps", type=int, default=500)
    parser.add_argument("--logging_steps", type=int, default=50)

    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
    )
    parser.add_argument("--lr_warmup_steps", type=int, default=100)

    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--prompt_length", type=int, default=32)
    parser.add_argument("--block_length", type=int, default=32)

    parser.add_argument("--lambda_conf", type=float, default=2.0)
    parser.add_argument("--conf_temperature", type=float, default=0.5)

    args = parser.parse_args()
    return TrainConfig(**vars(args))


def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
    """Tokenize a batch of examples, dropping empty / non-string rows."""
    texts = examples[text_column]
    texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
    return tokenizer(texts, truncation=True, padding=False, max_length=max_length)


class RandomTokenDataset(torch.utils.data.Dataset):
    """Synthetic dataset of uniformly random token IDs, used for smoke testing without downloads."""

    def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
        self.num_samples = int(num_samples)
        self.seq_len = int(seq_len)
        self.vocab_size = int(vocab_size)
        # Kept for interface parity with real tokenized data; samples are never padded.
        self.pad_token_id = int(pad_token_id)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        del idx  # Samples are i.i.d.; the index carries no information.
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}


def forward_process_semi_ar(
    input_ids: torch.LongTensor,
    attention_mask: torch.LongTensor,
    *,
    prompt_length: int,
    block_length: int,
    mask_token_id: int,
    generator: Optional[torch.Generator],
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
    """Apply semi-autoregressive block masking and its complement.

    For every block after the prompt, a per-sample mask ratio `p_mask` is drawn and positions are
    masked i.i.d. with that probability. A second, complementary view masks exactly the positions the
    first view left intact (within valid, non-padding tokens), so the two views jointly cover the block.

    Args:
        input_ids: `[batch, seq]` token IDs.
        attention_mask: `[batch, seq]` padding mask (1 = valid).
        prompt_length: number of leading positions that are never masked.
        block_length: size of each masking block.
        mask_token_id: token ID written at masked positions.
        generator: optional RNG for reproducible masking.

    Returns:
        `(noisy, noisy_rev, masked, masked_rev)`: the two masked views and their boolean mask maps.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    noisy = input_ids.clone()
    noisy_rev = input_ids.clone()
    masked = torch.zeros_like(input_ids, dtype=torch.bool)
    masked_rev = torch.zeros_like(input_ids, dtype=torch.bool)

    # Only mask non-padding positions after the prompt.
    valid = attention_mask.to(dtype=torch.bool)
    start = int(prompt_length)
    for block_start in range(start, seq_len, int(block_length)):
        block_end = min(seq_len, block_start + int(block_length))
        seg_len = block_end - block_start
        if seg_len <= 0:
            continue

        p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
        seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
        seg = seg & valid[:, block_start:block_end]
        seg_rev = (~seg) & valid[:, block_start:block_end]

        masked[:, block_start:block_end] = seg
        masked_rev[:, block_start:block_end] = seg_rev

    noisy = torch.where(masked, torch.full_like(noisy, int(mask_token_id)), noisy)
    noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, int(mask_token_id)), noisy_rev)
    return noisy, noisy_rev, masked, masked_rev


def main():
    """Entry point: build model/data, run the training loop, and save checkpoints."""
    cfg = parse_args()
    if cfg.prompt_length >= cfg.max_length:
        raise ValueError("`prompt_length` must be < `max_length`.")
    if cfg.block_length <= 0:
        raise ValueError("`block_length` must be > 0.")

    project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        project_config=project_config,
    )
    if accelerator.is_main_process:
        os.makedirs(cfg.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    set_seed(cfg.seed)
    logger.info("Training configuration: %s", asdict(cfg))

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name_or_path, cache_dir=cfg.cache_dir, torch_dtype=load_dtype
    )
    # The tokenizer may have grown (added [MASK]); keep the embedding table in sync.
    model.resize_token_embeddings(len(tokenizer))

    mask_token_id = int(tokenizer.mask_token_id)

    if cfg.use_dummy_data:
        dataset = RandomTokenDataset(
            num_samples=cfg.num_dummy_samples,
            seq_len=cfg.max_length,
            vocab_size=len(tokenizer),
            pad_token_id=int(tokenizer.pad_token_id),
        )
        train_dataloader = DataLoader(
            dataset,
            shuffle=True,
            batch_size=cfg.per_device_train_batch_size,
            drop_last=True,
        )
    else:
        raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
        if "train" not in raw_datasets:
            raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")

        with accelerator.main_process_first():
            tokenized = raw_datasets["train"].map(
                lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
                batched=True,
                remove_columns=raw_datasets["train"].column_names,
                desc="Tokenizing",
            )

        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
        train_dataloader = DataLoader(
            tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
    num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_warmup_steps,
        num_training_steps=cfg.max_train_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0  # Counts *optimizer* updates, i.e. only advances when gradients sync.
    micro_step = 0  # Counts every micro-batch; used to vary the masking RNG per batch.
    model.train()
    done = False

    for _epoch in range(num_train_epochs):
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))

                # Seed per *micro* batch: seeding with `global_step` would reuse identical
                # masking patterns across all micro-batches of one accumulation window.
                gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + micro_step)
                micro_step += 1
                noisy, noisy_rev, masked, masked_rev = forward_process_semi_ar(
                    input_ids,
                    attention_mask,
                    prompt_length=int(cfg.prompt_length),
                    block_length=int(cfg.block_length),
                    mask_token_id=mask_token_id,
                    generator=gen,
                )

                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
                )

                logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
                logits_rev = model(
                    input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
                ).logits

                # The model must never predict the mask token itself.
                logits = logits.clone()
                logits[..., mask_token_id] = torch.finfo(logits.dtype).min
                logits_rev = logits_rev.clone()
                logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min

                valid = attention_mask.to(dtype=torch.bool)
                masked = masked & valid
                masked_rev = masked_rev & valid

                # Only masked positions contribute to the loss (-100 is ignored by the loss).
                labels = input_ids.clone()
                labels[~masked] = -100
                labels_rev = input_ids.clone()
                labels_rev[~masked_rev] = -100

                weights = masked.to(dtype=logits.dtype)
                weights_rev = masked_rev.to(dtype=logits.dtype)

                loss, loss_sft, loss_conf = compute_confidence_aware_loss(
                    logits,
                    labels,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights,
                )
                loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
                    logits_rev,
                    labels_rev,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights_rev,
                )

                total_loss = loss + loss_rev
                accelerator.backward(total_loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Logging / checkpointing / termination are per optimizer update, so they must
            # live under `sync_gradients`; otherwise each accumulation micro-batch would
            # re-log and re-save at the same `global_step`.
            if accelerator.sync_gradients:
                global_step += 1

                if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
                    logger.info(
                        "step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
                        global_step,
                        total_loss.item(),
                        (loss_sft + loss_sft_rev).item(),
                        (loss_conf + loss_conf_rev).item(),
                        lr_scheduler.get_last_lr()[0],
                    )

                if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
                        os.makedirs(save_dir, exist_ok=True)
                        accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
                        tokenizer.save_pretrained(save_dir)

                if global_step >= cfg.max_train_steps:
                    done = True
                    break

        if done:
            break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final_dir = os.path.join(cfg.output_dir, "final")
        os.makedirs(final_dir, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
        tokenizer.save_pretrained(final_dir)

    logger.info("Done.")


if __name__ == "__main__":
    main()
[ + "BlockRefinementPipeline", + "BlockRefinementPipelineOutput", + ] + ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ "AmusedScheduler", + "BlockRefinementScheduler", + "BlockRefinementSchedulerOutput", "CMStochasticIterativeScheduler", "CogVideoXDDIMScheduler", "CogVideoXDPMScheduler", @@ -577,6 +585,8 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LLaDA2Pipeline", + "LLaDA2PipelineOutput", "LongCatImageEditPipeline", "LongCatImagePipeline", "LTX2ConditionPipeline", @@ -1100,6 +1110,8 @@ AutoPipelineForText2Image, BlipDiffusionControlNetPipeline, BlipDiffusionPipeline, + BlockRefinementPipeline, + BlockRefinementPipelineOutput, CLIPImageProjection, ConsistencyModelPipeline, DanceDiffusionPipeline, @@ -1119,6 +1131,8 @@ from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, + BlockRefinementScheduler, + BlockRefinementSchedulerOutput, CMStochasticIterativeScheduler, CogVideoXDDIMScheduler, CogVideoXDPMScheduler, @@ -1333,6 +1347,8 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LLaDA2Pipeline, + LLaDA2PipelineOutput, LongCatImageEditPipeline, LongCatImagePipeline, LTX2ConditionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8007035338b0..d12ce0939c0f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,6 +47,7 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] + _import_structure["block_refinement"] = ["BlockRefinementPipeline", "BlockRefinementPipelineOutput"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -285,6 +286,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] + _import_structure["llada2"] = ["LLaDA2Pipeline", 
"LLaDA2PipelineOutput"] _import_structure["ltx"] = [ "LTXPipeline", "LTXImageToVideoPipeline", @@ -542,6 +544,7 @@ AutoPipelineForInpainting, AutoPipelineForText2Image, ) + from .block_refinement import BlockRefinementPipeline, BlockRefinementPipelineOutput from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline @@ -728,6 +731,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline from .ltx import ( LTXConditionPipeline, diff --git a/src/diffusers/pipelines/block_refinement/__init__.py b/src/diffusers/pipelines/block_refinement/__init__.py new file mode 100644 index 000000000000..1eec2ee97e81 --- /dev/null +++ b/src/diffusers/pipelines/block_refinement/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .pipeline_block_refinement import BlockRefinementPipeline, BlockRefinementPipelineOutput + + +__all__ = ["BlockRefinementPipeline", "BlockRefinementPipelineOutput"] diff --git a/src/diffusers/pipelines/block_refinement/pipeline_block_refinement.py b/src/diffusers/pipelines/block_refinement/pipeline_block_refinement.py new file mode 100644 index 000000000000..219d5dc27422 --- /dev/null +++ b/src/diffusers/pipelines/block_refinement/pipeline_block_refinement.py @@ -0,0 +1,456 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import torch + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import BlockRefinementScheduler +from ...utils import BaseOutput +from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin + + +@dataclass +class BlockRefinementPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +class BlockRefinementPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin): + """ + Block-wise iterative refinement pipeline for token generation. + + This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. 
In each + refinement step, it samples candidate tokens for the active block and commits a subset based on confidence. + + The model is expected to accept an additive attention mask of shape `[batch, 1, seq, seq]` (0 for allowed, `-inf` + for disallowed) and `position_ids`, and to return logits of shape `[batch, seq, vocab_size]`. + """ + + model: Any + scheduler: BlockRefinementScheduler + tokenizer: Any + + _callback_tensor_inputs = ["cur_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"] + + def __init__( + self, + model: Any, + scheduler: BlockRefinementScheduler, + tokenizer: Optional[Any] = None, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer) + + @property + def num_timesteps(self): + return self._num_timesteps + + def _model_forward_logits( + self, + input_ids: torch.LongTensor, + *, + attention_mask_4d: Optional[torch.Tensor], + attention_mask_2d: Optional[torch.Tensor], + position_ids: torch.LongTensor, + attention_mask_mode: str, + ) -> tuple[torch.Tensor, str]: + if attention_mask_mode not in {"auto", "4d", "2d", "none"}: + raise ValueError( + f"`attention_mask_mode` must be one of {{'auto','4d','2d','none'}}, got {attention_mask_mode!r}." + ) + + def _call(mask): + return self.model(input_ids, attention_mask=mask, position_ids=position_ids).logits + + if attention_mask_mode == "none": + return _call(None), "none" + if attention_mask_mode == "2d": + return _call(attention_mask_2d), "2d" + if attention_mask_mode == "4d": + return _call(attention_mask_4d), "4d" + + # auto: try 4d additive mask first, then fall back to 2d padding mask, then no mask. 
+ try: + return _call(attention_mask_4d), "4d" + except (TypeError, ValueError, RuntimeError): + pass + try: + return _call(attention_mask_2d), "2d" + except (TypeError, ValueError, RuntimeError): + return _call(None), "none" + + def _build_block_attention_mask( + self, + *, + num_blocks: int, + block_length: int, + total_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device, dtype=torch.bool)) + attn = ( + block_mask.repeat_interleave(block_length, dim=0) + .repeat_interleave(block_length, dim=1) + .unsqueeze(0) + .unsqueeze(0) + ) + attn = attn[:, :, :total_length, :total_length] + return torch.where( + attn, + torch.zeros((), device=device, dtype=dtype), + torch.full((), float("-inf"), device=device, dtype=dtype), + ) + + def _encode_prompt( + self, + prompt: Optional[Union[str, List[str]]], + prompt_ids: Optional[torch.LongTensor], + *, + device: torch.device, + ) -> torch.LongTensor: + if prompt_ids is not None: + if prompt_ids.ndim == 1: + prompt_ids = prompt_ids.unsqueeze(0) + if prompt_ids.ndim != 2: + raise ValueError( + f"`prompt_ids` must have shape [prompt_len] or [batch, prompt_len], got {prompt_ids.shape}." 
+ ) + if prompt_ids.dtype != torch.long: + raise ValueError(f"`prompt_ids` must be int64 token IDs, got dtype={prompt_ids.dtype}.") + return prompt_ids.to(device=device) + + if prompt is None: + return torch.zeros((1, 0), device=device, dtype=torch.long) + if getattr(self, "tokenizer", None) is None: + raise ValueError("`prompt` requires a tokenizer, but no tokenizer was provided to the pipeline.") + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=True) + return encoded["input_ids"].to(device=device) + + def prepare_latents( + self, + batch_size: int, + total_length: int, + mask_token_id: int, + device: torch.device, + ) -> torch.LongTensor: + return torch.full((batch_size, total_length), int(mask_token_id), device=device, dtype=torch.long) + + def check_inputs( + self, + gen_length: int, + block_length: int, + steps: int, + minimal_topk: int, + threshold: float, + sampling_method: str, + callback_on_step_end: Optional[Union[Callable, PipelineCallback, MultiPipelineCallbacks]], + callback_on_step_end_tensor_inputs: Optional[List[str]], + ): + if gen_length <= 0: + raise ValueError(f"`gen_length` must be > 0, got {gen_length}.") + if block_length <= 0: + raise ValueError(f"`block_length` must be > 0, got {block_length}.") + if steps <= 0: + raise ValueError(f"`steps` must be > 0, got {steps}.") + if minimal_topk <= 0: + raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.") + if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0): + raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.") + if sampling_method not in {"auto", "greedy", "multinomial"}: + raise ValueError( + f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}." 
+ ) + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + prompt_ids: Optional[torch.LongTensor] = None, + gen_length: int = 128, + block_length: int = 32, + steps: int = 32, + temperature: float = 0.0, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + sampling_method: str = "auto", + threshold: float = 0.95, + editing_threshold: Optional[float] = None, + max_post_steps: int = 0, + minimal_topk: int = 1, + eos_early_stop: bool = False, + eos_token_id: Optional[int] = None, + mask_token_id: Optional[int] = None, + attention_mask_mode: str = "auto", + generator: Optional[torch.Generator] = None, + return_text: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> BlockRefinementPipelineOutput: + """ + Generate tokens with block-wise refinement. + + Args: + prompt (`str` or `List[str]`, *optional*): + Prompt text to encode with the tokenizer. + prompt_ids (`torch.LongTensor`, *optional*): + Pre-tokenized prompt IDs with shape `[prompt_len]` or `[batch, prompt_len]`. + gen_length (`int`): + Number of tokens to generate. + block_length (`int`): + Block size for refinement. + steps (`int`): + Refinement steps per block. + temperature (`float`): + Sampling temperature. 
+ top_p (`float`, *optional*): + Nucleus sampling cutoff. + top_k (`int`, *optional*): + Top-k sampling cutoff. + sampling_method (`str`): + Sampling method (`auto`, `greedy`, `multinomial`). + threshold (`float`): + Confidence threshold for committing tokens. + editing_threshold (`float`, *optional*): + Confidence threshold for editing already-committed (non-mask) tokens. When set, after all mask tokens + in a block are resolved, the pipeline continues refining: if the model predicts a different token with + confidence above this threshold, the existing token is replaced. Set to `None` or a negative value to + disable editing. Defaults to `None` (disabled). + max_post_steps (`int`): + Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only + used when `editing_threshold` is enabled. Defaults to `0` (no post-mask editing steps). + minimal_topk (`int`): + Minimum number of tokens to commit per step. + eos_early_stop (`bool`): + Whether to stop after committing EOS in a block. + eos_token_id (`int`, *optional*): + EOS token ID to use for early stopping. + mask_token_id (`int`, *optional*): + Mask token ID to use for the template. + attention_mask_mode (`str`): + Attention mask mode (`auto`, `4d`, `2d`, `none`). + generator (`torch.Generator`, *optional*): + RNG for sampling. + return_text (`bool`, *optional*, defaults to `True`): + Whether to decode sequences into text when a tokenizer is available. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each refinement step with signature `callback_on_step_end(self, step: int, + timestep: int, callback_kwargs: Dict)`. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor keys to pass to the callback. Allowed keys: `cur_x`, `x0`, `x0_p`, `transfer_index`, + `confidence`, `active_block`. 
+ """ + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["cur_x"] + + self.check_inputs( + gen_length=gen_length, + block_length=block_length, + steps=steps, + minimal_topk=minimal_topk, + threshold=threshold, + sampling_method=sampling_method, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + model_params = list(self.model.parameters()) if hasattr(self.model, "parameters") else [] + model_device = model_params[0].device if len(model_params) > 0 else torch.device("cpu") + + prompt_ids = self._encode_prompt(prompt, prompt_ids, device=model_device) + batch_size, prompt_length = prompt_ids.shape + + if eos_token_id is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).") + + steps = min(int(steps), int(gen_length) // int(minimal_topk)) + + self.scheduler.set_timesteps(steps, device=model_device) + + num_blocks = (prompt_length + int(gen_length) + int(block_length) - 1) // int(block_length) + total_length = int(num_blocks) * int(block_length) + + dtype = getattr(self.model, "dtype", torch.float32) + attn_dtype = torch.bfloat16 if dtype in (torch.bfloat16, torch.float16) else torch.float32 + attn_mask_4d = self._build_block_attention_mask( + num_blocks=num_blocks, + block_length=block_length, + total_length=total_length, + device=model_device, + dtype=attn_dtype, + ) + attn_mask_2d_full = torch.ones((batch_size, total_length), device=model_device, dtype=torch.long) + position_ids = ( + 
torch.arange(total_length, device=model_device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + ) + + x = self.prepare_latents(batch_size, total_length, int(mask_token_id), model_device) + if prompt_length > 0: + x[:, :prompt_length] = prompt_ids.to(device=model_device) + + prefill_blocks = prompt_length // int(block_length) + self._num_timesteps = int(steps) * max(int(num_blocks) - int(prefill_blocks), 0) + + finished = torch.zeros((batch_size,), device=model_device, dtype=torch.bool) + resolved_attention_mode: str = str(attention_mask_mode) + + use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and float(temperature) != 0.0) + editing_enabled = editing_threshold is not None and editing_threshold >= 0.0 + global_step = 0 + + for num_block in range(int(prefill_blocks), int(num_blocks)): + current_window_end = (num_block + 1) * int(block_length) + cur_x = x[:, :current_window_end] + cur_attn_mask_4d = attn_mask_4d[:, :, :current_window_end, :current_window_end] + cur_attn_mask_2d = attn_mask_2d_full[:, :current_window_end] + cur_position_ids = position_ids[:, :current_window_end] + + # Identify which positions in the block are prompt (non-editable). 
+ block_start_pos = num_block * int(block_length) + prompt_mask_in_block = torch.zeros(int(block_length), device=model_device, dtype=torch.bool) + if block_start_pos < prompt_length: + prompt_end_in_block = min(prompt_length - block_start_pos, int(block_length)) + prompt_mask_in_block[:prompt_end_in_block] = True + + post_steps = 0 + step_idx = 0 + while step_idx < int(steps) or (editing_enabled and post_steps <= int(max_post_steps)): + if finished.all(): + break + + block_tokens = cur_x[:, -int(block_length) :] + active_block = block_tokens == int(mask_token_id) + masks_remaining = active_block.any() + + if not masks_remaining and not editing_enabled: + break + if not masks_remaining: + post_steps += 1 + if post_steps > int(max_post_steps): + break + + logits, resolved_attention_mode = self._model_forward_logits( + cur_x, + attention_mask_4d=cur_attn_mask_4d, + attention_mask_2d=cur_attn_mask_2d, + position_ids=cur_position_ids, + attention_mask_mode=resolved_attention_mode, + ) + block_logits = logits[:, -int(block_length) :, :] + + x0, x0_p = self._sample_with_temperature_topk_topp( + block_logits, + temperature=float(temperature), + top_k=top_k, + top_p=top_p, + generator=generator, + use_multinomial=use_multinomial, + ) + + scheduler_output = self.scheduler.step( + sampled_tokens=x0, + sampled_probs=x0_p, + timestep=step_idx, + sample=block_tokens, + mask_token_id=int(mask_token_id), + threshold=float(threshold), + editing_threshold=editing_threshold, + minimal_topk=int(minimal_topk), + prompt_mask=prompt_mask_in_block, + generator=generator, + return_dict=True, + ) + + transfer_index = scheduler_output.transfer_index + editing_transfer_index = scheduler_output.editing_transfer_index + final_transfer = transfer_index | editing_transfer_index + + if final_transfer.any(): + cur_x[:, -int(block_length) :] = scheduler_output.prev_sample + + # Break if no masks remain and no edits were made. 
+ if not masks_remaining and not editing_transfer_index.any(): + break + + if eos_early_stop and eos_token_id is not None: + for b in range(batch_size): + if finished[b]: + continue + eos_in_commits = (x0[b][final_transfer[b]] == int(eos_token_id)).any().item() + if not eos_in_commits: + continue + eos_pos = (cur_x[b] == int(eos_token_id)).nonzero(as_tuple=True) + if len(eos_pos[0]) == 0: + continue + eos_pos = int(eos_pos[0][0].item()) + if prompt_length >= eos_pos: + continue + if (cur_x[b, prompt_length:eos_pos] != int(mask_token_id)).all().item(): + finished[b] = True + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs) + cur_x = callback_outputs.pop("cur_x", cur_x) + + global_step += 1 + if masks_remaining: + step_idx += 1 + + x[:, :current_window_end] = cur_x + if eos_token_id is not None and (x[:, prompt_length:current_window_end] == int(eos_token_id)).any().item(): + if eos_early_stop: + break + + generated = x[:, : prompt_length + int(gen_length)] + sequences = generated[:, prompt_length:] + if eos_token_id is not None and batch_size == 1: + eos_positions = (sequences[0] == int(eos_token_id)).nonzero(as_tuple=True)[0] + if len(eos_positions) > 0: + sequences = sequences[:, : int(eos_positions[0].item()) + 1] + + texts = None + if return_text and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + return BlockRefinementPipelineOutput(sequences=sequences.to(device=model_device), texts=texts) diff --git a/src/diffusers/pipelines/llada2/__init__.py b/src/diffusers/pipelines/llada2/__init__.py new file mode 100644 index 000000000000..45a02e6851e2 --- /dev/null +++ b/src/diffusers/pipelines/llada2/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + 
OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/llada2/pipeline_llada2.py b/src/diffusers/pipelines/llada2/pipeline_llada2.py new file mode 100644 index 000000000000..7b956469549e --- /dev/null +++ b/src/diffusers/pipelines/llada2/pipeline_llada2.py @@ -0,0 +1,188 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch + +from ...utils import BaseOutput, logging, replace_example_docstring +from ..block_refinement import BlockRefinementPipeline, BlockRefinementPipelineOutput + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> from diffusers import LLaDA2Pipeline + + >>> from diffusers import BlockRefinementScheduler + + >>> model_id = "inclusionAI/LLaDA2.0-mini" + >>> model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16) + >>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + >>> model = model.to("cuda") + >>> scheduler = BlockRefinementScheduler() + + >>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + >>> output = pipe(prompt="What is the meaning of life?", gen_length=256) + >>> print(output.texts[0]) + ``` +""" + + +@dataclass +class LLaDA2PipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +class LLaDA2Pipeline(BlockRefinementPipeline): + r""" + Adapter pipeline for LLaDA2-style discrete diffusion generation. + + This pipeline subclasses [`BlockRefinementPipeline`] and reuses its sampling loop. It only adapts prompt + preparation (including chat templates) and output formatting. 
+ """ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + messages: Optional[List[Dict[str, str]]] = None, + input_ids: Optional[torch.LongTensor] = None, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + gen_length: int = 2048, + block_length: int = 32, + steps: int = 32, + temperature: float = 0.0, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + sampling_method: str = "multinomial", + threshold: float = 0.95, + editing_threshold: Optional[float] = None, + max_post_steps: int = 0, + minimal_topk: int = 1, + eos_early_stop: bool = False, + eos_token_id: Optional[int] = None, + mask_token_id: Optional[int] = None, + attention_mask_mode: str = "4d", + generator: Optional[torch.Generator] = None, + return_text: bool = True, + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> Union[LLaDA2PipelineOutput, Tuple[torch.LongTensor, Optional[List[str]]]]: + """ + Generate text with block-wise refinement. 
+ + Examples: + """ + prompt_ids = self._prepare_prompt_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + ) + + output: BlockRefinementPipelineOutput = super().__call__( + prompt_ids=prompt_ids, + gen_length=gen_length, + block_length=block_length, + steps=steps, + temperature=temperature, + top_p=top_p, + top_k=top_k, + sampling_method=sampling_method, + threshold=threshold, + editing_threshold=editing_threshold, + max_post_steps=max_post_steps, + minimal_topk=minimal_topk, + eos_early_stop=eos_early_stop, + eos_token_id=eos_token_id, + mask_token_id=mask_token_id, + attention_mask_mode=attention_mask_mode, + generator=generator, + return_text=return_text, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + if not return_dict: + return output.sequences, output.texts + return LLaDA2PipelineOutput(sequences=output.sequences, texts=output.texts) + + def _prepare_prompt_ids( + self, + *, + prompt: Optional[Union[str, List[str]]], + messages: Optional[List[Dict[str, str]]], + input_ids: Optional[torch.LongTensor], + use_chat_template: bool, + add_generation_prompt: bool, + ) -> Optional[torch.LongTensor]: + if input_ids is not None: + return input_ids + + if self.tokenizer is None: + if prompt is None and messages is None: + return None + raise ValueError("Tokenizer is required to encode `prompt` or `messages`.") + + def _extract_input_ids(encoded): + if isinstance(encoded, dict) and "input_ids" in encoded: + return encoded["input_ids"] + if hasattr(encoded, "input_ids"): + return encoded.input_ids + return encoded + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + ) + return _extract_input_ids(encoded) + + if prompt is None: + return None + + if 
use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + ) + return _extract_input_ids(encoded) + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return _extract_input_ids(encoded) + + +__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d675f1de04a7..3b074832ed59 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -2383,3 +2383,176 @@ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): else: self.vae.unfuse_qkv_projections() self.fusing_vae = False + + +class DiscreteDiffusionPipelineMixin: + """Shared utilities for discrete (token) diffusion pipelines. + + Provides SAR sampling techniques (top-p, top-k) and common helper methods for pipelines that operate on discrete + token sequences. 
+ """ + + # --- SAR sampling utilities (static methods) --- + + @staticmethod + def _top_p_filtering(logits: "torch.Tensor", top_p: Optional[float]) -> "torch.Tensor": + """Nucleus (top-p) logit filtering.""" + if top_p is None or top_p >= 1.0: + return logits + if not (0.0 < top_p <= 1.0): + raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.") + + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = sorted_probs.cumsum(dim=-1) + + sorted_indices_to_remove = cumulative_probs > float(top_p) + sorted_indices_to_remove[..., 0] = 0 + + sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min) + filtered = logits.scatter(-1, sorted_indices, sorted_logits) + return filtered + + @staticmethod + def _top_k_filtering(logits: "torch.Tensor", top_k: Optional[int]) -> "torch.Tensor": + """Top-k logit filtering.""" + if top_k is None or top_k <= 0: + return logits + if top_k >= logits.shape[-1]: + return logits + values, _ = torch.topk(logits, k=int(top_k), dim=-1) + min_keep = values[..., -1, None] + return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min) + + @staticmethod + def _sample_with_temperature_topk_topp( + logits: "torch.Tensor", + *, + temperature: float, + top_k: Optional[int], + top_p: Optional[float], + generator: Optional["torch.Generator"], + use_multinomial: bool, + ) -> "tuple[torch.LongTensor, torch.Tensor]": + """Sample tokens from logits with temperature scaling, top-k, and top-p.""" + vocab_size = logits.shape[-1] + flat_logits = logits.reshape(-1, vocab_size) + + filtered = DiscreteDiffusionPipelineMixin._top_k_filtering(flat_logits, top_k=top_k) + filtered = DiscreteDiffusionPipelineMixin._top_p_filtering(filtered, top_p=top_p) + + if temperature < 0: + raise ValueError(f"`temperature` must be >= 0, got {temperature}.") + + scaled = filtered + if temperature > 0.0 and temperature != 
1.0: + scaled = filtered / float(temperature) + + probs = torch.softmax(scaled.float(), dim=-1) + if use_multinomial: + token = torch.multinomial(probs, num_samples=1, generator=generator) + else: + token = scaled.argmax(dim=-1, keepdim=True) + token_prob = torch.gather(probs, -1, token) + + return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1]) + + # --- Token/prefix utilities (instance methods) --- + + def _resolve_start_token_id(self) -> Optional[int]: + """Resolve BOS or CLS token ID from self.tokenizer.""" + tok = getattr(self, "tokenizer", None) + if tok is None: + return None + for attr in ("bos_token_id", "cls_token_id"): + token_id = getattr(tok, attr, None) + if token_id is not None: + return int(token_id) + return None + + def _normalize_prefix_ids( + self, prefix_ids: "torch.LongTensor", batch_size: int, device: "torch.device" + ) -> "torch.LongTensor": + """Validate shape/dtype and broadcast prefix token IDs.""" + if prefix_ids.ndim == 1: + prefix_ids = prefix_ids.unsqueeze(0) + if prefix_ids.ndim != 2: + raise ValueError( + f"`prefix_ids` must have shape [prefix_len] or [batch, prefix_len], got {prefix_ids.shape}." + ) + if prefix_ids.shape[0] not in (1, batch_size): + raise ValueError( + f"`prefix_ids` batch dim must be 1 or batch_size={batch_size}, got {prefix_ids.shape[0]}." 
+ ) + if prefix_ids.dtype != torch.long: + raise ValueError(f"`prefix_ids` must be int64 token IDs, got dtype={prefix_ids.dtype}.") + prefix_ids = prefix_ids.to(device=device) + if prefix_ids.shape[0] == 1 and batch_size > 1: + prefix_ids = prefix_ids.expand(batch_size, -1) + return prefix_ids + + # --- Prompt encoding (instance method) --- + + def _prepare_input_ids( + self, + *, + prompt: Optional[Union[str, List[str]]], + messages: Optional[List[Dict[str, str]]], + input_ids: Optional["torch.LongTensor"], + use_chat_template: bool, + add_generation_prompt: bool, + chat_template_kwargs: Optional[Dict[str, object]], + ) -> "torch.LongTensor": + """Convert prompt/messages/input_ids to a [batch, seq] LongTensor.""" + if input_ids is not None: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + if input_ids.ndim != 2: + raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + return input_ids + + if self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if messages is not None and prompt is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if messages is None and prompt is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + + chat_template_kwargs = chat_template_kwargs or {} + + def _extract_input_ids(encoded): + if isinstance(encoded, dict) and "input_ids" in encoded: + return encoded["input_ids"] + if hasattr(encoded, "input_ids"): + return encoded.input_ids + return encoded + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return _extract_input_ids(encoded) + + if use_chat_template and getattr(self.tokenizer, 
"chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return _extract_input_ids(encoded) + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return _extract_input_ids(encoded) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index c7101d1b0401..b1f75bed7dc5 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -40,6 +40,7 @@ else: _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] + _import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] @@ -145,6 +146,7 @@ else: from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler + from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler diff --git a/src/diffusers/schedulers/scheduling_block_refinement.py b/src/diffusers/schedulers/scheduling_block_refinement.py new file mode 100644 index 000000000000..18df74efba99 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_block_refinement.py @@ -0,0 +1,214 @@ +# Copyright 2025 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class BlockRefinementSchedulerOutput(BaseOutput): + """ + Output class for block refinement scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Updated block tokens after the current refinement step. + transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Boolean mask indicating which tokens were committed (mask-filling). + editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Boolean mask indicating which tokens were edited (non-mask replacement). + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Sampled token IDs from the model logits. + sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`): + Probabilities of the sampled tokens. + """ + + prev_sample: torch.LongTensor + transfer_index: torch.BoolTensor + editing_transfer_index: torch.BoolTensor + sampled_tokens: torch.LongTensor + sampled_probs: torch.Tensor + + +class BlockRefinementScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for block-wise iterative refinement (commit-by-confidence). 
+ + At each step, the scheduler samples candidate tokens and commits those with the highest confidence. The number of + tokens to commit per step is determined by evenly distributing the block length across the number of refinement + steps. + + Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a + different token with confidence above `editing_threshold`. + """ + + order = 1 + + @register_to_config + def __init__( + self, + block_length: int = 32, + num_inference_steps: int = 32, + threshold: float = 0.95, + editing_threshold: Optional[float] = None, + minimal_topk: int = 1, + ): + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long) + self._transfer_schedule: Optional[torch.LongTensor] = None + + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + self._transfer_schedule = self.get_num_transfer_tokens( + int(self.config.block_length), self.num_inference_steps + ).to(device=device if device is not None else "cpu") + + def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor: + """Evenly distribute `block_length` token commits across `num_inference_steps` steps.""" + if num_inference_steps <= 0: + return torch.zeros((0,), dtype=torch.long) + base = int(block_length) // int(num_inference_steps) + remainder = int(block_length) % int(num_inference_steps) + out = torch.full((int(num_inference_steps),), base, dtype=torch.long) + out[:remainder] += 1 + return out + + def step( + self, + sampled_tokens: torch.LongTensor, + sampled_probs: torch.Tensor, + 
timestep: Union[int, torch.Tensor], + sample: torch.LongTensor, + *, + mask_token_id: int, + threshold: Optional[float] = None, + editing_threshold: Optional[float] = None, + minimal_topk: Optional[int] = None, + prompt_mask: Optional[torch.BoolTensor] = None, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[ + BlockRefinementSchedulerOutput, + Tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor], + ]: + """ + Perform a single refinement step: commit confident tokens and optionally edit existing ones. + + Args: + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Candidate token IDs sampled from model logits. + sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`): + Confidence probabilities for the sampled tokens. + timestep (`int` or `torch.Tensor`): + Current step index within the block's refinement schedule. + sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Current block token IDs (contains mask tokens for uncommitted positions). + mask_token_id (`int`): + Token ID used for masked positions. + threshold (`float`, *optional*): + Confidence threshold for committing tokens. Defaults to config value. + editing_threshold (`float`, *optional*): + Confidence threshold for editing non-mask tokens. Defaults to config value. + minimal_topk (`int`, *optional*): + Minimum tokens to commit per step. Defaults to config value. + prompt_mask (`torch.BoolTensor`, *optional*): + Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions. + generator (`torch.Generator`, *optional*): + Unused, kept for API consistency. + return_dict (`bool`): + Whether to return a `BlockRefinementSchedulerOutput` or a tuple. 
+ """ + if threshold is None: + threshold = float(self.config.threshold) + if editing_threshold is None: + editing_threshold = self.config.editing_threshold + if minimal_topk is None: + minimal_topk = int(self.config.minimal_topk) + + batch_size, block_length = sample.shape + active_block = sample == int(mask_token_id) + masks_remaining = active_block.any() + + if isinstance(timestep, torch.Tensor): + step_index = int(timestep.item()) + else: + step_index = int(timestep) + + # --- Mask-filling transfer --- + transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool) + if masks_remaining and self._transfer_schedule is not None: + clamped_step = min(step_index, len(self._transfer_schedule) - 1) + num_to_transfer = int(self._transfer_schedule[clamped_step].item()) + + confidence = torch.where( + active_block, + sampled_probs.to(dtype=torch.float32), + torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32), + ) + + for b in range(batch_size): + high_conf = confidence[b] > float(threshold) + if high_conf.sum().item() >= num_to_transfer: + transfer_index[b] = high_conf + else: + k = min(num_to_transfer, int(active_block[b].sum().item())) + if k > 0: + _, idx = torch.topk(confidence[b], k=k) + transfer_index[b, idx] = True + + # --- Editing transfer (non-mask, non-prompt positions) --- + editing_enabled = editing_threshold is not None and editing_threshold >= 0.0 + editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool) + if editing_enabled: + if prompt_mask is None: + prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool) + editable = (~active_block) & (~prompt_mask.unsqueeze(0)) + editing_conf = torch.where( + editable, + sampled_probs.to(dtype=torch.float32), + torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32), + ) + high_conf_edit = editing_conf > float(editing_threshold) + token_changed = sampled_tokens != sample + editing_transfer_index = high_conf_edit & token_changed & editable + + # 
def compute_confidence_aware_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    lambda_conf: float = 0.0,
    temperature: float = 1.0,
    per_token_weights: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Confidence-aware training loss for token classification-style heads.

    The total loss is `loss_sft + lambda_conf * loss_conf`, where:
    - `loss_sft` is a (optionally per-token weighted) cross-entropy over all labels not equal to
      `ignore_index`;
    - `loss_conf` is the mean predictive entropy, accumulated only over tokens whose argmax prediction
      already matches the label — it sharpens the model where it is already right.

    Args:
        logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
        labels (`torch.Tensor`): Labels of shape `(...)` matching `logits.shape[:-1]`; entries equal to
            `ignore_index` are excluded from both terms.
        lambda_conf (`float`, *optional*, defaults to `0.0`): Weight of the entropy (confidence) term.
        temperature (`float`, *optional*, defaults to `1.0`): Temperature applied to the entropy term only.
        per_token_weights (`torch.Tensor`, *optional*): Per-token weights of shape `(...)`; zero-weight
            tokens contribute nothing to either term.
        ignore_index (`int`, *optional*, defaults to `-100`): Label value to ignore.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
    """
    # --- Input validation (kept in this order so error precedence is stable). ---
    if logits.ndim < 2:
        raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
    if labels.shape != logits.shape[:-1]:
        raise ValueError(
            f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
        )
    if temperature <= 0:
        raise ValueError(f"`temperature` must be > 0, got {temperature}.")

    label_mask = labels.ne(ignore_index)
    if per_token_weights is None:
        token_weights = torch.ones_like(labels, dtype=logits.dtype)
    else:
        if per_token_weights.shape != labels.shape:
            raise ValueError(
                f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
            )
        token_weights = per_token_weights.to(dtype=logits.dtype)

    # --- Supervised cross-entropy term. ---
    num_classes = logits.shape[-1]
    token_nll = F.cross_entropy(
        logits.reshape(-1, num_classes),
        labels.reshape(-1),
        reduction="none",
        ignore_index=ignore_index,
    ).reshape_as(labels)

    sft_denominator = (token_weights * label_mask.to(token_weights.dtype)).sum().clamp_min(1)
    loss_sft = (token_nll * token_weights * label_mask.to(token_nll.dtype)).sum() / sft_denominator

    # --- Entropy (confidence) term; skipped entirely when disabled. ---
    if lambda_conf == 0.0:
        zero_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
        return loss_sft, loss_sft, zero_conf

    with torch.no_grad():
        # Correctness is judged on the unscaled logits; only the entropy uses the temperature.
        predicted = logits.argmax(dim=-1)
        already_correct = label_mask & predicted.eq(labels)

    tempered = logits.float()
    if temperature != 1.0:
        tempered = tempered / float(temperature)

    token_probs = torch.softmax(tempered, dim=-1)
    floor = torch.finfo(token_probs.dtype).tiny
    token_log_probs = torch.log(token_probs.clamp_min(floor))
    token_entropy = -(token_probs * token_log_probs).sum(dim=-1).to(dtype=logits.dtype)

    conf_denominator = (token_weights * already_correct.to(token_weights.dtype)).sum().clamp_min(1)
    loss_conf = (token_entropy * token_weights * already_correct.to(token_entropy.dtype)).sum() / conf_denominator

    total = loss_sft + float(lambda_conf) * loss_conf
    return total, loss_sft, loss_conf
["torch"]) + + class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch"] @@ -2488,6 +2518,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BlockRefinementScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlockRefinementSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 730a788ed1b8..f4b886927650 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2207,6 +2207,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LLaDA2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LLaDA2PipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", 
"transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LongCatImageEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/others/test_training.py b/tests/others/test_training.py index 2038a98a813e..d8e86984ef1e 100644 --- a/tests/others/test_training.py +++ b/tests/others/test_training.py @@ -18,7 +18,7 @@ import torch from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel -from diffusers.training_utils import set_seed +from diffusers.training_utils import compute_confidence_aware_loss, set_seed from ..testing_utils import slow @@ -85,3 +85,47 @@ def test_training_step_equality(self): self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5)) self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5)) + + def test_confidence_aware_loss(self): + logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]]) + labels = torch.tensor([[0, 0]]) + weights = torch.tensor([[1.0, 2.0]]) + + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, labels, lambda_conf=0.0, per_token_weights=weights + ) + self.assertTrue(torch.allclose(loss, loss_sft)) + self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf))) + + lambda_conf = 0.25 + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, labels, lambda_conf=lambda_conf, per_token_weights=weights + ) + + # Manual expected values for the small 2-class case. 
+ per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view( + 1, 2 + ) + expected_sft = (per_token_nll * weights).sum() / weights.sum() + + pred = logits.argmax(dim=-1) + correct = pred.eq(labels) + log_probs = torch.log_softmax(logits.float(), dim=-1) + probs = log_probs.exp() + entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype) + expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / ( + weights * correct.to(weights.dtype) + ).sum().clamp_min(1) + + expected = expected_sft + lambda_conf * expected_conf + self.assertTrue(torch.allclose(loss_sft, expected_sft)) + self.assertTrue(torch.allclose(loss_conf, expected_conf)) + self.assertTrue(torch.allclose(loss, expected)) + + # Temperature affects only the confidence term. + loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss( + logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights + ) + self.assertTrue(torch.allclose(loss_sft_t, expected_sft)) + self.assertFalse(torch.allclose(loss_conf_t, expected_conf)) + self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t)) diff --git a/tests/pipelines/test_discrete_diffusion_mixin.py b/tests/pipelines/test_discrete_diffusion_mixin.py new file mode 100644 index 000000000000..affff621c0dd --- /dev/null +++ b/tests/pipelines/test_discrete_diffusion_mixin.py @@ -0,0 +1,274 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock + +import torch + +from diffusers.pipelines.pipeline_utils import DiscreteDiffusionPipelineMixin + + +class TestTopPFiltering(unittest.TestCase): + def test_top_p_filtering(self): + logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) + filtered = DiscreteDiffusionPipelineMixin._top_p_filtering(logits, top_p=0.5) + # Only the top token(s) summing to <= 0.5 probability should remain; + # the rest should be -inf (or dtype min). + # Verify that at least one token survived + self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any()) + # Verify that some tokens were filtered + self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any()) + + def test_top_p_filtering_none(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = DiscreteDiffusionPipelineMixin._top_p_filtering(logits, top_p=None) + self.assertTrue(torch.equal(result, logits)) + + def test_top_p_filtering_one(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = DiscreteDiffusionPipelineMixin._top_p_filtering(logits, top_p=1.0) + self.assertTrue(torch.equal(result, logits)) + + +class TestTopKFiltering(unittest.TestCase): + def test_top_k_filtering(self): + logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]]) + filtered = DiscreteDiffusionPipelineMixin._top_k_filtering(logits, top_k=2) + # Only the top-2 values (4.0 and 3.0) should survive + self.assertAlmostEqual(filtered[0, 1].item(), 4.0) + self.assertAlmostEqual(filtered[0, 3].item(), 3.0) + self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min) + self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min) + + def test_top_k_filtering_none(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = DiscreteDiffusionPipelineMixin._top_k_filtering(logits, top_k=None) + self.assertTrue(torch.equal(result, logits)) + + def test_top_k_filtering_zero(self): 
+ logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = DiscreteDiffusionPipelineMixin._top_k_filtering(logits, top_k=0) + self.assertTrue(torch.equal(result, logits)) + + def test_top_k_filtering_large_k(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = DiscreteDiffusionPipelineMixin._top_k_filtering(logits, top_k=100) + self.assertTrue(torch.equal(result, logits)) + + +class TestSampleWithTemperature(unittest.TestCase): + def test_greedy_sampling(self): + logits = torch.tensor([[1.0, 5.0, 2.0]]) + tokens, probs = DiscreteDiffusionPipelineMixin._sample_with_temperature_topk_topp( + logits, + temperature=0.0, + top_k=None, + top_p=None, + generator=None, + use_multinomial=False, + ) + self.assertEqual(tokens.item(), 1) # index of max logit (5.0) + self.assertEqual(tokens.shape, (1,)) + self.assertEqual(probs.shape, (1,)) + + def test_multinomial_sampling(self): + logits = torch.tensor([[0.0, 100.0, -100.0]]) + gen = torch.Generator().manual_seed(42) + tokens, probs = DiscreteDiffusionPipelineMixin._sample_with_temperature_topk_topp( + logits, + temperature=1.0, + top_k=None, + top_p=None, + generator=gen, + use_multinomial=True, + ) + # With such extreme logits, token should always be 1 + self.assertEqual(tokens.item(), 1) + + def test_temperature_scaling(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + # With very low temperature, should pick the max + tokens, _ = DiscreteDiffusionPipelineMixin._sample_with_temperature_topk_topp( + logits, + temperature=0.01, + top_k=None, + top_p=None, + generator=None, + use_multinomial=False, + ) + self.assertEqual(tokens.item(), 2) # index of max logit (3.0) + + def test_negative_temperature_raises(self): + logits = torch.tensor([[1.0, 2.0]]) + with self.assertRaises(ValueError, msg="`temperature` must be >= 0"): + DiscreteDiffusionPipelineMixin._sample_with_temperature_topk_topp( + logits, + temperature=-1.0, + top_k=None, + top_p=None, + generator=None, + use_multinomial=False, + ) + + +class 
TestResolveStartTokenId(unittest.TestCase): + def _make_mixin(self, tokenizer=None): + obj = DiscreteDiffusionPipelineMixin() + obj.tokenizer = tokenizer + return obj + + def test_no_tokenizer(self): + mixin = self._make_mixin(tokenizer=None) + self.assertIsNone(mixin._resolve_start_token_id()) + + def test_bos_token_id(self): + tok = MagicMock() + tok.bos_token_id = 1 + tok.cls_token_id = None + mixin = self._make_mixin(tokenizer=tok) + self.assertEqual(mixin._resolve_start_token_id(), 1) + + def test_cls_token_id_fallback(self): + tok = MagicMock() + tok.bos_token_id = None + tok.cls_token_id = 101 + mixin = self._make_mixin(tokenizer=tok) + self.assertEqual(mixin._resolve_start_token_id(), 101) + + def test_no_token_ids(self): + tok = MagicMock(spec=[]) + mixin = self._make_mixin(tokenizer=tok) + self.assertIsNone(mixin._resolve_start_token_id()) + + +class TestNormalizePrefixIds(unittest.TestCase): + def _make_mixin(self): + return DiscreteDiffusionPipelineMixin() + + def test_1d_input(self): + mixin = self._make_mixin() + prefix = torch.tensor([10, 20, 30], dtype=torch.long) + result = mixin._normalize_prefix_ids(prefix, batch_size=1, device=torch.device("cpu")) + self.assertEqual(result.shape, (1, 3)) + + def test_broadcast(self): + mixin = self._make_mixin() + prefix = torch.tensor([[10, 20]], dtype=torch.long) + result = mixin._normalize_prefix_ids(prefix, batch_size=4, device=torch.device("cpu")) + self.assertEqual(result.shape, (4, 2)) + self.assertTrue(torch.equal(result[0], result[3])) + + def test_wrong_dtype_raises(self): + mixin = self._make_mixin() + prefix = torch.tensor([1.0, 2.0]) + with self.assertRaises(ValueError, msg="int64"): + mixin._normalize_prefix_ids(prefix, batch_size=1, device=torch.device("cpu")) + + def test_wrong_batch_dim_raises(self): + mixin = self._make_mixin() + prefix = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.long) + with self.assertRaises(ValueError, msg="batch dim"): + mixin._normalize_prefix_ids(prefix, 
batch_size=2, device=torch.device("cpu")) + + +class TestPrepareInputIds(unittest.TestCase): + def _make_mixin(self, tokenizer=None): + obj = DiscreteDiffusionPipelineMixin() + obj.tokenizer = tokenizer + return obj + + def test_from_tensor(self): + mixin = self._make_mixin() + ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + result = mixin._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertTrue(torch.equal(result, ids)) + + def test_from_tensor_1d(self): + mixin = self._make_mixin() + ids = torch.tensor([1, 2, 3], dtype=torch.long) + result = mixin._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertEqual(result.shape, (1, 3)) + + def test_from_prompt(self): + tok = MagicMock() + tok.chat_template = None + tok.return_value = {"input_ids": torch.tensor([[10, 20, 30]])} + mixin = self._make_mixin(tokenizer=tok) + result = mixin._prepare_input_ids( + prompt="hello", + messages=None, + input_ids=None, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertEqual(result.shape, (1, 3)) + tok.assert_called_once() + + def test_no_tokenizer_raises(self): + mixin = self._make_mixin(tokenizer=None) + with self.assertRaises(ValueError, msg="Tokenizer is required"): + mixin._prepare_input_ids( + prompt="hello", + messages=None, + input_ids=None, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + + def test_both_prompt_and_messages_raises(self): + tok = MagicMock() + mixin = self._make_mixin(tokenizer=tok) + with self.assertRaises(ValueError, msg="not both"): + mixin._prepare_input_ids( + prompt="hello", + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + use_chat_template=False, + add_generation_prompt=False, + 
chat_template_kwargs=None, + ) + + def test_neither_prompt_nor_messages_raises(self): + tok = MagicMock() + mixin = self._make_mixin(tokenizer=tok) + with self.assertRaises(ValueError, msg="Provide one of"): + mixin._prepare_input_ids( + prompt=None, + messages=None, + input_ids=None, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/test_pipeline_block_refinement.py b/tests/pipelines/test_pipeline_block_refinement.py new file mode 100644 index 000000000000..9e280630a247 --- /dev/null +++ b/tests/pipelines/test_pipeline_block_refinement.py @@ -0,0 +1,95 @@ +import unittest + +import torch + +from diffusers import BlockRefinementPipeline, BlockRefinementScheduler + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyCausalLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = int(vocab_size) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + + # Make confidence vary with token position so top-k commits are deterministic. 
+ positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1) + token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1) + logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1) + return _DummyModelOutput(logits=logits) + + +class _DummyCausalLM2DOnly(_DummyCausalLM): + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError("2D attention_mask required") + return super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids, **kwargs) + + +class BlockRefinementPipelineTest(unittest.TestCase): + def test_pipeline_runs(self): + vocab_size = 32 + model = _DummyCausalLM(vocab_size=vocab_size) + scheduler = BlockRefinementScheduler() + pipe = BlockRefinementPipeline(model=model, scheduler=scheduler, tokenizer=None).to("cpu") + + prompt_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long) + out = pipe( + prompt_ids=prompt_ids, + gen_length=24, + block_length=8, + steps=8, + temperature=0.0, + threshold=2.0, # force top-k commits + minimal_topk=1, + eos_early_stop=False, + mask_token_id=vocab_size - 1, + eos_token_id=None, + return_text=False, + ) + + self.assertEqual(out.sequences.shape, (2, 24)) + self.assertFalse((out.sequences == vocab_size - 1).any().item()) + + def test_pipeline_falls_back_to_2d_attention_mask(self): + vocab_size = 32 + model = _DummyCausalLM2DOnly(vocab_size=vocab_size) + scheduler = BlockRefinementScheduler() + pipe = BlockRefinementPipeline(model=model, scheduler=scheduler, tokenizer=None).to("cpu") + + out = pipe( + prompt_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=16, + block_length=8, + steps=4, + temperature=0.0, + threshold=2.0, + minimal_topk=1, + eos_early_stop=False, + mask_token_id=vocab_size - 1, + eos_token_id=None, + 
attention_mask_mode="auto", + return_text=False, + ) + + self.assertEqual(out.sequences.shape, (1, 16)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_block_refinement.py b/tests/schedulers/test_scheduler_block_refinement.py new file mode 100644 index 000000000000..63e2cee55850 --- /dev/null +++ b/tests/schedulers/test_scheduler_block_refinement.py @@ -0,0 +1,284 @@ +import tempfile +import unittest + +import torch + +from diffusers import BlockRefinementScheduler + + +class BlockRefinementSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = { + "block_length": 32, + "num_inference_steps": 8, + "threshold": 0.95, + "editing_threshold": None, + "minimal_topk": 1, + } + config.update(kwargs) + return BlockRefinementScheduler(**config) + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + self.assertEqual(scheduler.num_inference_steps, 8) + self.assertEqual(len(scheduler.timesteps), 8) + # Timesteps should count down + self.assertEqual(scheduler.timesteps[0].item(), 7) + self.assertEqual(scheduler.timesteps[-1].item(), 0) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + + def test_get_num_transfer_tokens_even(self): + scheduler = self.get_scheduler() + schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8) + self.assertEqual(schedule.sum().item(), 32) + self.assertEqual(len(schedule), 8) + # 32 / 8 = 4 each, no remainder + self.assertTrue((schedule == 4).all().item()) + + def test_get_num_transfer_tokens_remainder(self): + scheduler = self.get_scheduler() + schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3) + self.assertEqual(schedule.sum().item(), 10) + self.assertEqual(len(schedule), 3) + # 10 / 3 = 3 base, 1 remainder -> [4, 3, 3] + self.assertEqual(schedule[0].item(), 4) + 
self.assertEqual(schedule[1].item(), 3) + self.assertEqual(schedule[2].item(), 3) + + def test_transfer_schedule_created_on_set_timesteps(self): + scheduler = self.get_scheduler(block_length=16) + scheduler.set_timesteps(4) + self.assertIsNotNone(scheduler._transfer_schedule) + self.assertEqual(scheduler._transfer_schedule.sum().item(), 16) + + def test_save_load_config_round_trip(self): + scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2) + with tempfile.TemporaryDirectory() as tmpdir: + scheduler.save_config(tmpdir) + loaded = BlockRefinementScheduler.from_pretrained(tmpdir) + + self.assertEqual(loaded.config.block_length, 64) + self.assertEqual(loaded.config.threshold, 0.8) + self.assertEqual(loaded.config.editing_threshold, 0.5) + self.assertEqual(loaded.config.minimal_topk, 2) + + def test_from_config(self): + scheduler = self.get_scheduler(block_length=16, threshold=0.7) + new_scheduler = BlockRefinementScheduler.from_config(scheduler.config) + self.assertEqual(new_scheduler.config.block_length, 16) + self.assertEqual(new_scheduler.config.threshold, 0.7) + + def test_step_commits_tokens(self): + """Verify that step() commits mask tokens based on confidence.""" + scheduler = self.get_scheduler(block_length=8) + scheduler.set_timesteps(2) + + batch_size, block_length = 1, 8 + mask_id = 99 + + # All positions are masked + sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long) + sampled_tokens = torch.arange(block_length, dtype=torch.long).unsqueeze(0) + # Confidence decreasing: first tokens are most confident + sampled_probs = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]]) + + out = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=mask_id, + threshold=0.95, + return_dict=True, + ) + + # With 8 tokens and 2 steps, first step should commit 4 tokens + committed = out.transfer_index[0].sum().item() + 
self.assertEqual(committed, 4) + # The 4 most confident (highest prob) should be committed + self.assertTrue(out.transfer_index[0, 0].item()) + self.assertTrue(out.transfer_index[0, 1].item()) + self.assertTrue(out.transfer_index[0, 2].item()) + self.assertTrue(out.transfer_index[0, 3].item()) + + def test_step_threshold_commits_all_above(self): + """When enough tokens exceed threshold, commit all of them (not just num_to_transfer).""" + scheduler = self.get_scheduler(block_length=8) + scheduler.set_timesteps(4) # 2 tokens per step + + batch_size, block_length = 1, 8 + mask_id = 99 + + sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long) + sampled_tokens = torch.arange(block_length, dtype=torch.long).unsqueeze(0) + # 5 tokens above threshold of 0.5 + sampled_probs = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.55, 0.1, 0.1, 0.1]]) + + out = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=mask_id, + threshold=0.5, + return_dict=True, + ) + + # All 5 above threshold should be committed (more than num_to_transfer=2) + committed = out.transfer_index[0].sum().item() + self.assertEqual(committed, 5) + + def test_step_no_editing_by_default(self): + """Without editing_threshold, no non-mask tokens should be changed.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long) + sampled_tokens = torch.tensor([[50, 60, 70, 80]], dtype=torch.long) + sampled_probs = torch.tensor([[0.99, 0.99, 0.99, 0.99]]) + + out = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=99, + editing_threshold=None, + return_dict=True, + ) + + # Non-mask positions should not be edited + self.assertFalse(out.editing_transfer_index.any().item()) + # Only mask positions should be committed + self.assertFalse(out.transfer_index[0, 0].item()) + 
self.assertFalse(out.transfer_index[0, 1].item()) + + def test_step_editing_replaces_tokens(self): + """With editing_threshold, non-mask tokens with high confidence and different prediction get replaced.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long) + # Token 0: model predicts 50 (different from 10) with high confidence + # Token 1: model predicts 20 (same as current) — should NOT edit + sampled_tokens = torch.tensor([[50, 20, 70, 80]], dtype=torch.long) + sampled_probs = torch.tensor([[0.99, 0.99, 0.5, 0.5]]) + + out = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=99, + editing_threshold=0.8, + return_dict=True, + ) + + # Token 0 should be edited (different prediction, high confidence) + self.assertTrue(out.editing_transfer_index[0, 0].item()) + # Token 1 should NOT be edited (same prediction) + self.assertFalse(out.editing_transfer_index[0, 1].item()) + # prev_sample should reflect the edit + self.assertEqual(out.prev_sample[0, 0].item(), 50) + + def test_step_prompt_mask_prevents_editing(self): + """Prompt positions should never be edited even with editing enabled.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long) + sampled_tokens = torch.tensor([[50, 60, 70, 80]], dtype=torch.long) + sampled_probs = torch.tensor([[0.99, 0.99, 0.99, 0.99]]) + prompt_mask = torch.tensor([True, True, False, False]) + + out = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=99, + editing_threshold=0.5, + prompt_mask=prompt_mask, + return_dict=True, + ) + + # Prompt positions should not be edited + self.assertFalse(out.editing_transfer_index[0, 0].item()) + self.assertFalse(out.editing_transfer_index[0, 1].item()) + + def 
test_step_return_tuple(self): + """Verify tuple output when return_dict=False.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + sample = torch.full((1, 4), 99, dtype=torch.long) + sampled_tokens = torch.arange(4, dtype=torch.long).unsqueeze(0) + sampled_probs = torch.ones(1, 4) + + result = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=99, + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 5) + + def test_step_batched(self): + """Verify step works with batch_size > 1.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + batch_size = 3 + mask_id = 99 + sample = torch.full((batch_size, 4), mask_id, dtype=torch.long) + sampled_tokens = torch.arange(4, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + sampled_probs = torch.rand(batch_size, 4) + + out = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=mask_id, + return_dict=True, + ) + + self.assertEqual(out.prev_sample.shape, (batch_size, 4)) + self.assertEqual(out.transfer_index.shape, (batch_size, 4)) + + def test_step_output_shape_matches_input(self): + """All output tensors should match the input sample shape.""" + scheduler = self.get_scheduler(block_length=8) + scheduler.set_timesteps(4) + + sample = torch.full((2, 8), 99, dtype=torch.long) + sampled_tokens = torch.zeros_like(sample) + sampled_probs = torch.rand(2, 8) + + out = scheduler.step( + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + timestep=0, + sample=sample, + mask_token_id=99, + return_dict=True, + ) + + self.assertEqual(out.prev_sample.shape, sample.shape) + self.assertEqual(out.transfer_index.shape, sample.shape) + self.assertEqual(out.editing_transfer_index.shape, sample.shape) + self.assertEqual(out.sampled_tokens.shape, sample.shape) + 
self.assertEqual(out.sampled_probs.shape, sample.shape) + + +if __name__ == "__main__": + unittest.main()