diff --git a/gradio/gradio_infer_demo.py b/gradio/gradio_infer_demo.py index 5b0b225..95787b9 100644 --- a/gradio/gradio_infer_demo.py +++ b/gradio/gradio_infer_demo.py @@ -119,7 +119,7 @@ def update_task(hf_model_id: str) -> Tuple[gr.Dropdown, gr.Component]: def update_subcheckpoints(checkpoint_dir): """Get subdirectories for the selected checkpoint directory.""" if checkpoint_dir == "None": - return gr.Dropdown(choices=[], interactive=False, visible=False) + return gr.Dropdown(choices=["None"], value="None", interactive=False, visible=False) # Get the full path to the checkpoint directory full_checkpoint_path = os.path.join(checkpoint_rootdir, checkpoint_dir) @@ -138,7 +138,7 @@ def update_subcheckpoints(checkpoint_dir): if not subdirs: # If there are no subdirectories, hide the dropdown - return gr.Dropdown(choices=[], interactive=False, visible=False) + return gr.Dropdown(choices=["None"], value="None", interactive=False, visible=False) # Show dropdown with available subdirectories return gr.Dropdown( @@ -183,6 +183,7 @@ def load_model_and_generate( ) # Load LoRA weights if selected + unload_lora_checkpoint(pipeline) if lora_checkpoint != "None": progress(0.3, desc="Loading LoRA weights...") # Construct the full path to the specific checkpoint @@ -192,8 +193,6 @@ def load_model_and_generate( lora_path = lora_checkpoint logger.info(f"Loading LoRA weights from {lora_path}") load_lora_checkpoint(pipeline, lora_path) - else: - unload_lora_checkpoint(pipeline) # Generate content based on task progress(0.5, desc="Generating content...") @@ -300,7 +299,7 @@ def load_model_and_generate( guidance_scale = gr.Slider( minimum=1.0, maximum=15.0, - value=6.0, + value=5.0, step=0.1, label="Guidance Scale", info="Higher values increase prompt adherence", diff --git a/gradio/gradio_lora_demo.py b/gradio/gradio_lora_demo.py index 609df9e..95e45ed 100644 --- a/gradio/gradio_lora_demo.py +++ b/gradio/gradio_lora_demo.py @@ -11,7 +11,6 @@ from torchvision.io import write_video from utils import ( BaseTask, - flatten_dict, get_dataset_dirs, get_logger, get_lora_checkpoint_rootdir, @@ -24,6 +23,7 @@ import gradio as gr from cogkit import GenerationMode, guess_generation_mode +from cogkit.utils import flatten_dict # ======================= global state ==================== diff --git a/gradio/utils/__init__.py b/gradio/utils/__init__.py index 9b92bad..e6dbb7b 100644 --- a/gradio/utils/__init__.py +++ b/gradio/utils/__init__.py @@ -8,7 +8,7 @@ resolve_path, ) from .logging import get_logger -from .misc import flatten_dict, get_resolutions +from .misc import get_resolutions from .task import BaseTask __all__ = [ @@ -22,5 +22,4 @@ "resolve_path", "BaseTask", "get_resolutions", - "flatten_dict", ] diff --git a/gradio/utils/misc.py b/gradio/utils/misc.py index b548d29..f4c15a8 100644 --- a/gradio/utils/misc.py +++ b/gradio/utils/misc.py @@ -1,9 +1,7 @@ -from typing import Any, Dict, List - from cogkit import GenerationMode -def get_resolutions(task: GenerationMode) -> List[str]: +def get_resolutions(task: GenerationMode) -> list[str]: if task == GenerationMode.TextToImage: return [ "512x512", @@ -19,35 +17,3 @@ def get_resolutions(task: GenerationMode) -> List[str]: "49x480x720", "81x768x1360", ] - - -def flatten_dict(d: Dict[str, Any], ignore_none: bool = False) -> Dict[str, Any]: - """ - Flattens a nested dictionary into a single layer dictionary. 
- - Args: - d: The dictionary to flatten - ignore_none: If True, keys with None values will be omitted - - Returns: - A flattened dictionary - - Raises: - ValueError: If there are duplicate keys across nested dictionaries - """ - result = {} - - def _flatten(current_dict, result_dict): - for key, value in current_dict.items(): - if value is None and ignore_none: - continue - - if isinstance(value, dict): - _flatten(value, result_dict) - else: - if key in result_dict: - raise ValueError(f"Duplicate key '{key}' found in nested dictionary") - result_dict[key] = value - - _flatten(d, result) - return result diff --git a/pyproject.toml b/pyproject.toml index 12fb154..0ddad69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ finetune = [ "datasets~=3.4", "deepspeed~=0.16.4", "av~=14.2.0", + "bitsandbytes~=0.45.4", + "tensorboard~=2.19", ] [project.urls] diff --git a/quickstart/scripts/train.py b/quickstart/scripts/train.py index 3433ec0..e899273 100644 --- a/quickstart/scripts/train.py +++ b/quickstart/scripts/train.py @@ -7,7 +7,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--model_name", type=str, required=True) parser.add_argument("--training_type", type=str, required=True) - parser.add_argument("--enable_packing", action="store_true") + parser.add_argument("--enable_packing", type=lambda x: x.lower() == "true") args, unknown = parser.parse_known_args() trainer_cls = get_model_cls(args.model_name, args.training_type, args.enable_packing) diff --git a/quickstart/scripts/train_ddp_i2v.sh b/quickstart/scripts/train_ddp_i2v.sh index 1ff0928..57824eb 100755 --- a/quickstart/scripts/train_ddp_i2v.sh +++ b/quickstart/scripts/train_ddp_i2v.sh @@ -21,11 +21,6 @@ OUTPUT_ARGS=( # Data Configuration DATA_ARGS=( --data_root "/path/to/data" - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) ) # Training Configuration @@ -35,13 +30,18 @@ TRAIN_ARGS=( --batch_size 1 --gradient_accumulation_steps 1 --mixed_precision "bf16" # ["no", "fp16"] - --learning_rate 2e-5 + --learning_rate 5e-5 + + # Note: + # for CogVideoX series models, number of training frames should be **8N+1** + # for CogVideoX1.5 series models, number of training frames should be **16N+1** + --train_resolution "81x768x1360" # (frames x height x width) ) # System Configuration SYSTEM_ARGS=( --num_workers 8 - --pin_memory True + --pin_memory true --nccl_timeout 1800 ) diff --git a/quickstart/scripts/train_ddp_t2i.sh b/quickstart/scripts/train_ddp_t2i.sh index eb38055..6aae45f 100755 --- a/quickstart/scripts/train_ddp_t2i.sh +++ b/quickstart/scripts/train_ddp_t2i.sh @@ -21,10 +21,6 @@ OUTPUT_ARGS=( # Data Configuration DATA_ARGS=( --data_root "/path/to/data" - - # Note: - # For CogView4 series models, height and width should be **32N** (multiple of 32) - --train_resolution "1024x1024" # (height x width) ) # Training Configuration @@ -32,15 +28,31 @@ TRAIN_ARGS=( --seed 42 # random seed --train_epochs 1 # number of training epochs --batch_size 1 + --gradient_accumulation_steps 1 + + # Note: For CogView4 series models, height and width should be **32N** (multiple of 32) + --train_resolution "1024x1024" # (height x width) + + # When enable_packing is true, training will use the native image resolution + # (otherwise all images will be resized to train_resolution, which may distort the original aspect ratio). 
+ # + # IMPORTANT: When changing enable_packing from true to false (or vice versa), + # make sure to clear the .cache directories in your data_root/train and data_root/test folders if they exist. + --enable_packing false + --mixed_precision "bf16" # ["no", "fp16"] - --learning_rate 2e-5 + --learning_rate 5e-5 + + # enable --low_vram will slow down validation speed and enable quantization during training + # Note: --low_vram currently does not support multi-GPU training + --low_vram false ) # System Configuration SYSTEM_ARGS=( --num_workers 8 - --pin_memory True + --pin_memory true --nccl_timeout 1800 ) diff --git a/quickstart/scripts/train_ddp_t2v.sh b/quickstart/scripts/train_ddp_t2v.sh index 2cd7aea..ca31e49 100755 --- a/quickstart/scripts/train_ddp_t2v.sh +++ b/quickstart/scripts/train_ddp_t2v.sh @@ -20,11 +20,6 @@ OUTPUT_ARGS=( # Data Configuration DATA_ARGS=( --data_root "/path/to/data" - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) ) # Training Configuration @@ -34,13 +29,18 @@ TRAIN_ARGS=( --batch_size 1 --gradient_accumulation_steps 1 --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training - --learning_rate 2e-5 + --learning_rate 5e-5 + + # Note: + # for CogVideoX series models, number of training frames should be **8N+1** + # for CogVideoX1.5 series models, number of training frames should be **16N+1** + --train_resolution "81x768x1360" # (frames x height x width) ) # System Configuration SYSTEM_ARGS=( --num_workers 8 - --pin_memory True + --pin_memory true --nccl_timeout 1800 ) diff --git a/quickstart/scripts/train_zero_i2v.sh b/quickstart/scripts/train_zero_i2v.sh index 4ebdd89..bfa07b7 100755 --- a/quickstart/scripts/train_zero_i2v.sh +++ b/quickstart/scripts/train_zero_i2v.sh @@ -20,11 +20,6 @@ OUTPUT_ARGS=( # Data Configuration DATA_ARGS=( --data_root "/path/to/data" - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) ) # Training Configuration @@ -32,19 +27,25 @@ TRAIN_ARGS=( --seed 42 # random seed --train_epochs 1 # number of training epochs - --learning_rate 2e-5 + --learning_rate 5e-5 ######### Please keep consistent with deepspeed config file ########## --batch_size 1 --gradient_accumulation_steps 1 --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training ######################################################################## + + # Note: + # for CogVideoX series models, number of training frames should be **8N+1** + # for CogVideoX1.5 series models, number of training frames should be **16N+1** + --train_resolution "81x768x1360" # (frames x height x width) + ) # System Configuration SYSTEM_ARGS=( --num_workers 8 - --pin_memory True + --pin_memory true --nccl_timeout 1800 ) diff --git a/quickstart/scripts/train_zero_t2i.sh b/quickstart/scripts/train_zero_t2i.sh index a83b2fb..878cd27 100755 --- a/quickstart/scripts/train_zero_t2i.sh +++ b/quickstart/scripts/train_zero_t2i.sh @@ -20,10 +20,6 @@ OUTPUT_ARGS=( # Data Configuration DATA_ARGS=( --data_root "/path/to/data" - - # Note: - # For CogView4 series models, height and width should be **32N** (multiple of 32) - --train_resolution "1024x1024" # (height x width) ) # Training Configuration @@ 
-31,7 +27,10 @@ TRAIN_ARGS=( --seed 42 # random seed --train_epochs 1 # number of training epochs - --learning_rate 2e-5 + --learning_rate 5e-5 + + # Note: For CogView4 series models, height and width should be **32N** (multiple of 32) + --train_resolution "1024x1024" # (height x width) ######### Please keep consistent with deepspeed config file ########## --batch_size 1 @@ -39,12 +38,19 @@ TRAIN_ARGS=( --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training ######################################################################## + # When enable_packing is true, training will use the native image resolution + # (otherwise all images will be resized to train_resolution, which may distort the original aspect ratio). + # + # IMPORTANT: When changing enable_packing from true to false (or vice versa), + # make sure to clear the .cache directories in your data_root/train and data_root/test folders if they exist. + --enable_packing false + ) # System Configuration SYSTEM_ARGS=( --num_workers 8 - --pin_memory True + --pin_memory true --nccl_timeout 1800 ) @@ -62,7 +68,7 @@ VALIDATION_ARGS=( ) # Combine all arguments and launch training -accelerate launch --config_file ../configs/accelerate_config.yaml train.py \ +accelerate launch --config_file ../configs/accelerate_config.yaml train.py\ "${MODEL_ARGS[@]}" \ "${OUTPUT_ARGS[@]}" \ "${DATA_ARGS[@]}" \ diff --git a/quickstart/scripts/train_zero_t2v.sh b/quickstart/scripts/train_zero_t2v.sh index dd630ae..516afc2 100755 --- a/quickstart/scripts/train_zero_t2v.sh +++ b/quickstart/scripts/train_zero_t2v.sh @@ -20,11 +20,6 @@ OUTPUT_ARGS=( # Data Configuration DATA_ARGS=( --data_root "/path/to/data" - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) ) # Training Configuration @@ -32,19 +27,24 @@ TRAIN_ARGS=( --seed 42 # random seed --train_epochs 1 # number of training epochs - --learning_rate 2e-5 + --learning_rate 5e-5 ######### Please keep consistent with deepspeed config file ########## --batch_size 1 --gradient_accumulation_steps 1 --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training ######################################################################## + + # Note: + # for CogVideoX series models, number of training frames should be **8N+1** + # for CogVideoX1.5 series models, number of training frames should be **16N+1** + --train_resolution "81x768x1360" # (frames x height x width) ) # System Configuration SYSTEM_ARGS=( --num_workers 8 - --pin_memory True + --pin_memory true --nccl_timeout 1800 ) diff --git a/src/cogkit/__init__.py b/src/cogkit/__init__.py index 71fee95..a32a259 100644 --- a/src/cogkit/__init__.py +++ b/src/cogkit/__init__.py @@ -8,6 +8,9 @@ load_lora_checkpoint, load_pipeline, unload_lora_checkpoint, + inject_lora, + save_lora, + unload_lora, ) __all__ = [ @@ -18,4 +21,7 @@ "unload_lora_checkpoint", "guess_generation_mode", "GenerationMode", + "inject_lora", + "save_lora", + "unload_lora", ] diff --git a/src/cogkit/api/services/image_generation.py b/src/cogkit/api/services/image_generation.py index e86afe2..4e75265 100644 --- a/src/cogkit/api/services/image_generation.py +++ b/src/cogkit/api/services/image_generation.py @@ -24,7 +24,7 @@ def __init__(self, settings: APISettings) -> None: torch_dtype = torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32 cogview4_pl = 
load_pipeline( model_id_or_path=settings.cogview4_path, - lora_model_id_or_path=None, + lora_model_id_or_path=settings.lora_dir, transformer_path=settings.cogview4_transformer_path, dtype=torch_dtype, ) @@ -66,7 +66,7 @@ def generate( f"Loading LORA weights from {adapter_name} and unload previous weights {self._current_lora[model]}" ) unload_lora_checkpoint(self._models[model]) - load_lora_checkpoint(self._models[model], lora_path, lora_scale) + load_lora_checkpoint(self._models[model], lora_path) else: _logger.info(f"Unloading LORA weights {self._current_lora[model]}") unload_lora_checkpoint(self._models[model]) diff --git a/src/cogkit/datasets/__init__.py b/src/cogkit/datasets/__init__.py index 72dc194..e440d41 100644 --- a/src/cogkit/datasets/__init__.py +++ b/src/cogkit/datasets/__init__.py @@ -3,14 +3,18 @@ from cogkit.datasets.i2v_dataset import BaseI2VDataset, I2VDatasetWithResize from cogkit.datasets.t2v_dataset import BaseT2VDataset, T2VDatasetWithResize -from cogkit.datasets.t2i_dataset import BaseT2IDataset, T2IDatasetWithResize, T2IDatasetWithPacking +from cogkit.datasets.t2i_dataset import ( + T2IDatasetWithFactorResize, + T2IDatasetWithResize, + T2IDatasetWithPacking, +) __all__ = [ "BaseI2VDataset", "I2VDatasetWithResize", "BaseT2VDataset", "T2VDatasetWithResize", - "BaseT2IDataset", + "T2IDatasetWithFactorResize", "T2IDatasetWithResize", "T2IDatasetWithPacking", ] diff --git a/src/cogkit/datasets/t2i_dataset.py b/src/cogkit/datasets/t2i_dataset.py index 7bbb56e..c67f149 100644 --- a/src/cogkit/datasets/t2i_dataset.py +++ b/src/cogkit/datasets/t2i_dataset.py @@ -1,20 +1,23 @@ -import torchvision.transforms as transforms +import math from pathlib import Path from typing import TYPE_CHECKING, Any, Tuple -from PIL import Image import torch +import torchvision.transforms as transforms from accelerate.logging import get_logger from datasets import load_dataset +from PIL import Image from torch.utils.data import Dataset from typing_extensions import override from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME from .utils import ( - preprocess_image_with_resize, - get_prompt_embedding, get_image_embedding, + get_prompt_embedding, + pil2tensor, + preprocess_image_with_resize, + calculate_resize_dimensions, ) if TYPE_CHECKING: @@ -63,7 +66,6 @@ def __init__( self.encode_image = trainer.encode_image self.trainer = trainer - self.to_tensor = transforms.ToTensor() self._image_transforms = transforms.Compose( [ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), @@ -122,8 +124,10 @@ def preprocess( Returns: - image(torch.Tensor) of shape [C, H, W] + + **Note**: The value of returned image tensor should be the float value in the range of 0 ~ 255(rather than 0 ~ 1). """ - return self.to_tensor(image) + return pil2tensor(image) def image_transform(self, image: torch.Tensor) -> torch.Tensor: """ @@ -171,6 +175,55 @@ def preprocess( return image +class T2IDatasetWithFactorResize(BaseT2IDataset): + """ + A dataset class that resizes images to dimensions that are multiples of a specified factor. + + If the image dimensions are not divisible by the factor, the image is resized + to the nearest larger dimensions that are divisible by the factor. 
+ + Args: + factor (int): The factor that image dimensions should be divisible by + """ + + def __init__( + self, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.factor = self.trainer.IMAGE_FACTOR + + @override + def preprocess( + self, + image: Image.Image, + device: torch.device = torch.device("cpu"), + ) -> torch.Tensor: + """ + Preprocesses an image by resizing it to dimensions that are multiples of self.factor. + + Args: + image: PIL.Image.Image object + device: Device to load the data on + + Returns: + torch.Tensor: Processed image tensor of shape [C, H, W] + """ + # Get original dimensions + width, height = image.size + maxpixels = self.trainer.state.train_resolution[0] * self.trainer.state.train_resolution[1] + new_height, new_width = calculate_resize_dimensions(height, width, maxpixels) + + # Calculate nearest multiples of factor (rounding down) + new_height = math.floor(new_height / self.factor) * self.factor + new_width = math.floor(new_width / self.factor) * self.factor + + assert new_height > 0 and new_width > 0, "Have image with height or width <= self.factor" + + return preprocess_image_with_resize(image, new_height, new_width, device) + + class T2IDatasetWithPacking(Dataset): """ This dataset class packs multiple samples from a base Text-to-Image dataset. @@ -179,13 +232,15 @@ class T2IDatasetWithPacking(Dataset): def __init__( self, - base_dataset: BaseT2IDataset, + base_dataset: T2IDatasetWithFactorResize, *args, **kwargs, ) -> None: super().__init__(*args, **kwargs) - assert type(base_dataset) is BaseT2IDataset # should literally be a BaseT2IDataset + # base_dataset should be a T2IDatasetWithFactorResize + assert type(base_dataset) is T2IDatasetWithFactorResize + self.base_dataset = base_dataset def __getitem__(self, index: list[int]) -> dict[str, Any]: diff --git a/src/cogkit/datasets/utils.py b/src/cogkit/datasets/utils.py index ed1bd19..fb6e87e 100644 --- a/src/cogkit/datasets/utils.py +++ b/src/cogkit/datasets/utils.py @@ -1,5 +1,6 @@ import hashlib import logging +import math from pathlib import Path from typing import Callable @@ -67,6 +68,12 @@ def load_images_from_videos(videos_path: list[Path]) -> list[Path]: ########## preprocessors ########## +def pil2tensor(image: Image.Image) -> torch.Tensor: + image = image.convert("RGB") + image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float().contiguous() + return image + + def preprocess_image_with_resize( image: Image.Image, height: int, @@ -242,3 +249,31 @@ def get_image_embedding( ) return encoded_image + + +def calculate_resize_dimensions(height: int, width: int, max_pixels: int) -> tuple[int, int]: + """ + Calculate new dimensions for an image while maintaining aspect ratio and limiting total pixels. 
+ + Args: + height (int): Original height of the image + width (int): Original width of the image + max_pixels (int): Maximum number of pixels allowed + + Returns: + Tuple[int, int]: New (width, height) dimensions + """ + current_pixels = width * height + + # If current pixel count is already below max, return original dimensions + if current_pixels <= max_pixels: + return height, width + + # Calculate scaling factor to maintain aspect ratio + scale = math.sqrt(max_pixels / current_pixels) + + # Calculate new dimensions + new_height = int(height * scale) + new_width = int(width * scale) + + return new_height, new_width diff --git a/src/cogkit/finetune/base/base_args.py b/src/cogkit/finetune/base/base_args.py index bd7fb97..07e048d 100644 --- a/src/cogkit/finetune/base/base_args.py +++ b/src/cogkit/finetune/base/base_args.py @@ -37,6 +37,7 @@ class BaseArgs(BaseModel): gradient_accumulation_steps: int = 1 mixed_precision: Literal["no", "fp16", "bf16"] + low_vram: bool = False learning_rate: float = 2e-5 optimizer: str = "adamw" @@ -47,8 +48,8 @@ class BaseArgs(BaseModel): weight_decay: float = 1e-4 max_grad_norm: float = 1.0 - lr_scheduler: str = "constant_with_warmup" - lr_warmup_steps: int = 100 + lr_scheduler: str = "linear" + lr_warmup_ratio: float = 0.01 lr_num_cycles: int = 1 lr_power: float = 1.0 @@ -67,6 +68,12 @@ class BaseArgs(BaseModel): do_validation: bool = False validation_steps: int | None # if set, should be a multiple of checkpointing_steps + @field_validator("low_vram") + def validate_low_vram(cls, v: bool, info: ValidationInfo) -> bool: + if v and info.data.get("training_type") != "lora": + raise ValueError("low_vram can only be True when training_type is 'lora'") + return v + @field_validator("validation_steps") def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None: values = info.data @@ -114,18 +121,21 @@ def get_base_parser(cls): parser.add_argument("--max_grad_norm", type=float, default=1.0) # Learning rate scheduler - parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=100) + parser.add_argument("--lr_scheduler", type=str, default="linear") + parser.add_argument("--lr_warmup_ratio", type=float, default=0.01) parser.add_argument("--lr_num_cycles", type=int, default=1) parser.add_argument("--lr_power", type=float, default=1.0) # Data loading parser.add_argument("--num_workers", type=int, default=8) - parser.add_argument("--pin_memory", type=bool, default=True) + parser.add_argument("--pin_memory", type=lambda x: x.lower() == "true", default=True) # Model configuration parser.add_argument("--mixed_precision", type=str, default="no") - parser.add_argument("--gradient_checkpointing", type=bool, default=True) + parser.add_argument("--low_vram", type=lambda x: x.lower() == "true", default=False) + parser.add_argument( + "--gradient_checkpointing", type=lambda x: x.lower() == "true", default=True + ) parser.add_argument("--nccl_timeout", type=int, default=1800) # LoRA parameters diff --git a/src/cogkit/finetune/base/base_state.py b/src/cogkit/finetune/base/base_state.py index 2b791c2..a3307b6 100644 --- a/src/cogkit/finetune/base/base_state.py +++ b/src/cogkit/finetune/base/base_state.py @@ -8,7 +8,6 @@ class BaseState(BaseModel): weight_dtype: torch.dtype = torch.float32 # dtype for mixed precision training num_trainable_parameters: int = 0 - overwrote_max_train_steps: bool = False num_update_steps_per_epoch: int = 0 total_batch_size_count: int = 0 diff 
--git a/src/cogkit/finetune/base/base_trainer.py b/src/cogkit/finetune/base/base_trainer.py index 1410766..ea32c6a 100644 --- a/src/cogkit/finetune/base/base_trainer.py +++ b/src/cogkit/finetune/base/base_trainer.py @@ -20,24 +20,21 @@ set_seed, ) from diffusers.optimization import get_scheduler -from peft import ( - LoraConfig, - get_peft_model_state_dict, - set_peft_model_state_dict, -) from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from cogkit.finetune.base import BaseArgs, BaseComponents, BaseState +from cogkit.utils.lora import inject_lora, save_lora from ..utils import ( cast_training_params, free_memory, - get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from, get_memory_statistics, get_optimizer, unwrap_model, + find_files, + delete_files, ) _DTYPE_MAP = { @@ -63,24 +60,23 @@ class BaseTrainer(ABC): def __init__(self) -> None: self.logger = get_logger(self.LOG_NAME, self.LOG_LEVEL) - - self.args = self._init_args() - self.components = self.load_components() - self.state = self._init_state() - self.accelerator: Accelerator = None self.train_dataset: Dataset = None self.test_dataset: Dataset = None self.train_data_loader: DataLoader = None self.test_data_loader: DataLoader = None - self.optimizer = None self.lr_scheduler = None + self.args = self._init_args() + self.state = self._init_state() + self._init_distributed() self._init_logging() self._init_directories() + self.components = self.load_components() + self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None def _init_distributed(self): @@ -109,6 +105,12 @@ def _init_distributed(self): self.accelerator = accelerator + tracker_name = self.args.tracker_name + self.accelerator.init_trackers( + project_name=tracker_name, + init_kwargs={"wandb": {"name": self.args.output_dir.name}}, + ) + if self.args.seed is not None: set_seed(self.args.seed) @@ -168,25 +170,16 @@ def prepare_trainable_parameters(self) -> None: component.requires_grad_(False) if self.args.training_type == "lora": - transformer_lora_config = LoraConfig( - r=self.args.rank, - lora_alpha=self.args.lora_alpha, - init_lora_weights=True, - target_modules=self.args.target_modules, - ) - self.components.transformer.add_adapter(transformer_lora_config) - self.prepare_saving_loading_hooks(transformer_lora_config) - - # Load components needed for training to GPU (except transformer), and cast them to the specified data type - ignore_list = ["transformer"] + self.UNLOAD_LIST - self.move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list) + # Initialize LoRA weights + inject_lora(self.components.transformer, lora_dir_or_state_dict=None) + self.prepare_saving_loading_hooks() if self.args.gradient_checkpointing: self.components.transformer.enable_gradient_checkpointing() def prepare_optimizer(self) -> None: # Make sure the trainable params are in float32 - cast_training_params([self.components.transformer], dtype=torch.float32) + # cast_training_params([self.components.transformer], dtype=torch.float32) # For LoRA, we only want to train the LoRA weights # For SFT, we want to train all the parameters @@ -220,12 +213,12 @@ def prepare_optimizer(self) -> None: use_deepspeed=use_deepspeed_opt, ) + # Do not need to divide by num_gpus since acclerate will handle this after prepare lr_scheduler num_update_steps_per_epoch = math.ceil( len(self.train_data_loader) / self.args.gradient_accumulation_steps ) - if self.args.train_steps is None: - self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch - 
self.state.overwrote_max_train_steps = True + total_train_steps = self.args.train_epochs * num_update_steps_per_epoch + total_num_warmup_steps = max(int(total_train_steps * self.args.lr_warmup_ratio), 0) use_deepspeed_lr_scheduler = ( self.accelerator.state.deepspeed_plugin is not None @@ -238,15 +231,15 @@ def prepare_optimizer(self) -> None: lr_scheduler = DummyScheduler( name=self.args.lr_scheduler, optimizer=optimizer, - total_num_steps=self.args.train_steps, - num_warmup_steps=self.args.lr_warmup_steps, + total_num_steps=total_train_steps, + num_warmup_steps=total_num_warmup_steps, ) else: lr_scheduler = get_scheduler( name=self.args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=self.args.lr_warmup_steps, - num_training_steps=self.args.train_steps, + num_warmup_steps=total_num_warmup_steps, + num_training_steps=total_train_steps, num_cycles=self.args.lr_num_cycles, power=self.args.lr_power, ) @@ -255,6 +248,9 @@ def prepare_optimizer(self) -> None: self.lr_scheduler = lr_scheduler def prepare_for_training(self) -> None: + # cast training params to the specified data type (bf16) + cast_training_params(self.components.transformer, dtype=self.state.weight_dtype) + ( self.components.transformer, self.optimizer, @@ -267,24 +263,25 @@ def prepare_for_training(self) -> None: self.lr_scheduler, ) + # Load components needed for training to GPU (except transformer), and cast them to the specified data type + ignore_list = self.UNLOAD_LIST + self.move_components_to_device( + dtype=self.state.weight_dtype, device=self.accelerator.device, ignore_list=ignore_list + ) + if self.args.do_validation: assert self.test_data_loader is not None self.test_data_loader = self.accelerator.prepare_data_loader(self.test_data_loader) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed in distributed training num_update_steps_per_epoch = math.ceil( len(self.train_data_loader) / self.args.gradient_accumulation_steps ) - if self.state.overwrote_max_train_steps: - self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch + self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch) self.state.num_update_steps_per_epoch = num_update_steps_per_epoch - def prepare_trackers(self) -> None: - tracker_name = self.args.tracker_name - self.accelerator.init_trackers(tracker_name, config=self.args.model_dump()) - def train(self) -> None: memory_statistics = get_memory_statistics(self.logger) self.logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") @@ -338,6 +335,7 @@ def train(self) -> None: self.state.generator = generator free_memory() + ckpt_path = None for epoch in range(first_epoch, self.args.train_epochs): self.logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})") @@ -377,25 +375,27 @@ def train(self) -> None: if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - self.maybe_save_checkpoint(global_step) + ckpt_path = self.maybe_save_checkpoint(global_step) - logs["loss"] = loss.detach().item() - logs["lr"] = self.lr_scheduler.get_last_lr()[0] - progress_bar.set_postfix(logs) + logs["loss"] = loss.detach().item() + logs["lr"] = self.lr_scheduler.get_last_lr()[0] + progress_bar.set_postfix(logs) - # Maybe run validation - should_run_validation = ( - self.args.do_validation and global_step % self.args.validation_steps == 0 - ) - if should_run_validation: - del loss - free_memory() - self.validate(global_step) + # Maybe run validation + should_run_validation = ( + self.args.do_validation + and global_step % self.args.validation_steps == 0 + and accelerator.sync_gradients + ) + if should_run_validation: + del loss + free_memory() + self.validate(global_step, ckpt_path=ckpt_path) - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=global_step) - if global_step >= self.args.train_steps: - break + if global_step >= self.args.train_steps: + break memory_statistics = get_memory_statistics(self.logger) self.logger.info( @@ -403,10 +403,10 @@ def train(self) -> None: ) accelerator.wait_for_everyone() - self.maybe_save_checkpoint(global_step, must_save=True) + ckpt_path = self.maybe_save_checkpoint(global_step, must_save=True) if self.args.do_validation: free_memory() - self.validate(global_step) + self.validate(global_step, ckpt_path=ckpt_path) del self.components free_memory() @@ -434,9 +434,6 @@ def fit(self) -> None: self.logger.info("Preparing for training...") self.prepare_for_training() - self.logger.info("Initializing trackers...") - self.prepare_trackers() - self.logger.info("Starting training...") self.train() @@ -470,7 +467,7 @@ def compute_loss(self, batch) -> torch.Tensor: raise NotImplementedError @abstractmethod - def validate(self, step: int) -> None: + def validate(self, step: int, ckpt_path: str | None = None) -> None: # validation logic defined here # during validation, additional modules in the pipeline may need to be moved to GPU memory raise NotImplementedError @@ -485,102 +482,76 @@ def get_training_dtype(self) -> torch.dtype: else: raise ValueError(f"Invalid mixed precision: 
{self.args.mixed_precision}") - def move_components_to_device(self, dtype, ignore_list: list[str] = []): + def move_components_to_device(self, dtype, device, ignore_list: list[str] = []): ignore_list = set(ignore_list) components = self.components.model_dump() for name, component in components.items(): - if not isinstance(component, type) and hasattr(component, "to"): - if name not in ignore_list: - setattr( - self.components, - name, - component.to(self.accelerator.device, dtype=dtype), - ) - - def move_components_to_cpu(self, unload_list: list[str] = []): - unload_list = set(unload_list) - components = self.components.model_dump() - for name, component in components.items(): - if not isinstance(component, type) and hasattr(component, "to"): - if name in unload_list: - setattr(self.components, name, component.to("cpu")) + if ( + not isinstance(component, type) + and hasattr(component, "to") + and name not in ignore_list + ): + setattr( + self.components, + name, + component.to(device, dtype=dtype), + ) - def prepare_saving_loading_hooks(self, transformer_lora_config): + def prepare_saving_loading_hooks(self): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): - if self.accelerator.is_main_process: - transformer_lora_layers_to_save = None - - for model in models: - if isinstance( - unwrap_model(self.accelerator, model), - type(unwrap_model(self.accelerator, self.components.transformer)), - ): - model = unwrap_model(self.accelerator, model) - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - else: - raise ValueError(f"Unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() - - self.components.pipeline_cls.save_lora_weights( - output_dir, - transformer_lora_layers=transformer_lora_layers_to_save, - ) + assert self.accelerator.distributed_type != DistributedType.DEEPSPEED + + for model in models: + original_model = unwrap_model(self.accelerator, model) + original_transformer = unwrap_model(self.accelerator, self.components.transformer) + if isinstance(original_model, type(original_transformer)): + if self.accelerator.is_main_process: + save_lora(model, output_dir) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() def load_model_hook(models, input_dir): - if self.accelerator.distributed_type != DistributedType.DEEPSPEED: - while len(models) > 0: - model = models.pop() - if isinstance( - unwrap_model(self.accelerator, model), - type(unwrap_model(self.accelerator, self.components.transformer)), - ): - transformer_ = unwrap_model(self.accelerator, model) - else: - raise ValueError( - f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}" - ) - else: - transformer_ = unwrap_model( - self.accelerator, self.components.transformer - ).__class__.from_pretrained(self.args.model_path, subfolder="transformer") - transformer_.add_adapter(transformer_lora_config) - - lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir) - transformer_state_dict = { - f"{k.replace('transformer.', '')}": v - for k, v in lora_state_dict.items() - if k.startswith("transformer.") - } - incompatible_keys = set_peft_model_state_dict( - transformer_, transformer_state_dict, adapter_name="default" - ) - if incompatible_keys is not None: - # check only 
for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - self.logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) + assert self.accelerator.distributed_type != DistributedType.DEEPSPEED + + for model in models: + original_model = unwrap_model(self.accelerator, model) + original_transformer = unwrap_model(self.accelerator, self.components.transformer) + if isinstance(original_model, type(original_transformer)): + inject_lora(model, input_dir) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") self.accelerator.register_save_state_pre_hook(save_model_hook) self.accelerator.register_load_state_pre_hook(load_model_hook) - def maybe_save_checkpoint(self, global_step: int, must_save: bool = False): - if ( - self.accelerator.distributed_type == DistributedType.DEEPSPEED - or self.accelerator.is_main_process - ): - if must_save or global_step % self.args.checkpointing_steps == 0: - # for training - save_path = get_intermediate_ckpt_path( - checkpointing_limit=self.args.checkpointing_limit, - step=global_step, - output_dir=self.args.output_dir, - logger=self.logger, - ) - self.accelerator.save_state(save_path, safe_serialization=True) + def maybe_save_checkpoint(self, global_step: int, must_save: bool = False) -> str | None: + if not (must_save or global_step % self.args.checkpointing_steps == 0): + return None + + checkpointing_limit = self.args.checkpointing_limit + output_dir = Path(self.args.output_dir) + logger = self.logger + + if checkpointing_limit is not None: + checkpoints = find_files(output_dir, prefix="checkpoint") + + # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpointing_limit: + num_to_remove = len(checkpoints) - checkpointing_limit + 1 + checkpoints_to_remove = checkpoints[0:num_to_remove] + if self.accelerator.is_main_process: + delete_files(checkpoints_to_remove, logger) + + logger.info(f"Checkpointing at step {global_step}") + save_path = output_dir / f"checkpoint-{global_step}" + logger.info(f"Saving state to {save_path}") + + self.accelerator.save_state(save_path, safe_serialization=True) + + self.accelerator.wait_for_everyone() + return save_path diff --git a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py index d71f40f..4311014 100644 --- a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py +++ b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py @@ -12,37 +12,70 @@ ) from diffusers.models.embeddings import get_3d_rotary_pos_embed from PIL import Image -from transformers import AutoTokenizer, T5EncoderModel +from transformers import AutoTokenizer, T5EncoderModel, BitsAndBytesConfig from typing_extensions import override from cogkit.finetune import register from cogkit.finetune.diffusion.schemas import DiffusionComponents from cogkit.finetune.diffusion.trainer import DiffusionTrainer from cogkit.finetune.utils import unwrap_model +from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint class CogVideoXI2VLoraTrainer(DiffusionTrainer): UNLOAD_LIST = ["text_encoder"] + NEGATIVE_PROMPT = "" @override def load_components(self) -> DiffusionComponents: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + 
bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + dtype = self.state.weight_dtype + components = DiffusionComponents() model_path = str(self.args.model_path) + ### pipeline components.pipeline_cls = CogVideoXImageToVideoPipeline + ### tokenizer components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + ### text encoder components.text_encoder = T5EncoderModel.from_pretrained( - model_path, subfolder="text_encoder" + model_path, + subfolder="text_encoder", + torch_dtype=dtype, ) - components.transformer = CogVideoXTransformer3DModel.from_pretrained( - model_path, subfolder="transformer" - ) + ### transformer + if not self.args.low_vram: + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, + subfolder="transformer", + torch_dtype=dtype, + ) + else: + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, + subfolder="transformer", + quantization_config=nf4_config, + device=self.accelerator.device, + torch_dtype=dtype, + ) - components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae") + ### vae + components.vae = AutoencoderKLCogVideoX.from_pretrained( + model_path, + subfolder="vae", + torch_dtype=dtype, + ) + ### scheduler components.scheduler = CogVideoXDPMScheduler.from_pretrained( model_path, subfolder="scheduler" ) @@ -50,14 +83,31 @@ def load_components(self) -> DiffusionComponents: return components @override - def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline: - pipe = CogVideoXImageToVideoPipeline( - tokenizer=self.components.tokenizer, - text_encoder=self.components.text_encoder, - vae=self.components.vae, - transformer=unwrap_model(self.accelerator, self.components.transformer), - scheduler=self.components.scheduler, - ) + def initialize_pipeline(self, ckpt_path: str | None = None) -> CogVideoXImageToVideoPipeline: + if not self.args.low_vram: + pipe = CogVideoXImageToVideoPipeline( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=unwrap_model(self.accelerator, self.components.transformer), + scheduler=self.components.scheduler, + ) + else: + assert self.args.training_type == "lora" + transformer = CogVideoXTransformer3DModel.from_pretrained( + str(self.args.model_path), + subfolder="transformer", + torch_dtype=self.state.weight_dtype, + ) + pipe = CogVideoXImageToVideoPipeline( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=transformer, + scheduler=self.components.scheduler, + ) + unload_lora_checkpoint(pipe) + load_lora_checkpoint(pipe, ckpt_path) return pipe @override @@ -88,6 +138,10 @@ def encode_text(self, prompt: str) -> torch.Tensor: assert prompt_embedding.ndim == 2 return prompt_embedding + @override + def get_negtive_prompt_embeds(self) -> torch.Tensor: + return self.encode_text(self.NEGATIVE_PROMPT) + @override def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]: ret = { @@ -244,11 +298,7 @@ def compute_loss(self, batch) -> torch.Tensor: @override def validation_step( self, eval_data: dict[str, Any], pipe: CogVideoXImageToVideoPipeline - ) -> list[tuple[str, Image.Image | list[Image.Image]]]: - """ - Return the data that needs to be saved. 
For videos, the data format is List[PIL], - and for images, the data format is PIL - """ + ) -> dict[str, str | Image.Image | list[Image.Image]]: prompt, prompt_embedding, image, _ = ( eval_data["prompt"], eval_data["prompt_embedding"], @@ -261,10 +311,11 @@ def validation_step( height=self.state.train_resolution[1], width=self.state.train_resolution[2], prompt_embeds=prompt_embedding, + negative_prompt_embeds=self.get_negtive_prompt_embeds().unsqueeze(0), image=image, generator=self.state.generator, ).frames[0] - return [("prompt", prompt), ("image", image), ("video", video_generate)] + return {"text": prompt, "image": image, "video": video_generate} def prepare_rotary_positional_embeddings( self, diff --git a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py index 1a21f05..cfa2594 100644 --- a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py +++ b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py @@ -12,37 +12,66 @@ ) from diffusers.models.embeddings import get_3d_rotary_pos_embed from PIL import Image -from transformers import AutoTokenizer, T5EncoderModel +from transformers import AutoTokenizer, T5EncoderModel, BitsAndBytesConfig from typing_extensions import override from cogkit.finetune import register from cogkit.finetune.diffusion.schemas import DiffusionComponents from cogkit.finetune.diffusion.trainer import DiffusionTrainer from cogkit.finetune.utils import unwrap_model +from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint class CogVideoXT2VLoraTrainer(DiffusionTrainer): UNLOAD_LIST = ["text_encoder", "vae"] + NEGATIVE_PROMPT = "" @override def load_components(self) -> DiffusionComponents: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + dtype = self.state.weight_dtype + components = DiffusionComponents() model_path = str(self.args.model_path) + ### pipeline components.pipeline_cls = CogVideoXPipeline + ### tokenizer components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + ### text encoder components.text_encoder = T5EncoderModel.from_pretrained( - model_path, subfolder="text_encoder" + model_path, subfolder="text_encoder", torch_dtype=dtype ) - components.transformer = CogVideoXTransformer3DModel.from_pretrained( - model_path, subfolder="transformer" - ) + ### transformer + if not self.args.low_vram: + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, + subfolder="transformer", + torch_dtype=dtype, + ) + else: + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, + subfolder="transformer", + quantization_config=nf4_config, + device=self.accelerator.device, + torch_dtype=dtype, + ) - components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae") + ### vae + components.vae = AutoencoderKLCogVideoX.from_pretrained( + model_path, subfolder="vae", torch_dtype=dtype + ) + ### scheduler components.scheduler = CogVideoXDPMScheduler.from_pretrained( model_path, subfolder="scheduler" ) @@ -50,14 +79,31 @@ def load_components(self) -> DiffusionComponents: return components @override - def initialize_pipeline(self) -> CogVideoXPipeline: - pipe = CogVideoXPipeline( - tokenizer=self.components.tokenizer, - text_encoder=self.components.text_encoder, - vae=self.components.vae, - 
transformer=unwrap_model(self.accelerator, self.components.transformer), - scheduler=self.components.scheduler, - ) + def initialize_pipeline(self, ckpt_path: str | None = None) -> CogVideoXPipeline: + if not self.args.low_vram: + pipe = CogVideoXPipeline( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=unwrap_model(self.accelerator, self.components.transformer), + scheduler=self.components.scheduler, + ) + else: + assert self.args.training_type == "lora" + transformer = CogVideoXTransformer3DModel.from_pretrained( + str(self.args.model_path), + subfolder="transformer", + torch_dtype=self.state.weight_dtype, + ) + pipe = CogVideoXPipeline( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=transformer, + scheduler=self.components.scheduler, + ) + unload_lora_checkpoint(pipe) + load_lora_checkpoint(pipe, ckpt_path) return pipe @override @@ -88,6 +134,10 @@ def encode_text(self, prompt: str) -> torch.Tensor: assert prompt_embedding.ndim == 2 return prompt_embedding + @override + def get_negtive_prompt_embeds(self) -> torch.Tensor: + return self.encode_text(self.NEGATIVE_PROMPT) + @override def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]: ret = {"prompt": [], "prompt_embedding": [], "encoded_videos": []} @@ -193,11 +243,7 @@ def compute_loss(self, batch) -> torch.Tensor: @override def validation_step( self, eval_data: dict[str, Any], pipe: CogVideoXPipeline - ) -> list[tuple[str, Image.Image | list[Image.Image]]]: - """ - Return the data that needs to be saved. For videos, the data format is List[PIL], - and for images, the data format is PIL - """ + ) -> dict[str, str | list[Image.Image]]: prompt = eval_data["prompt"] prompt_embedding = eval_data["prompt_embedding"] @@ -206,9 +252,10 @@ def validation_step( height=self.state.train_resolution[1], width=self.state.train_resolution[2], prompt_embeds=prompt_embedding, + negative_prompt_embeds=self.get_negtive_prompt_embeds().unsqueeze(0), generator=self.state.generator, ).frames[0] - return [("text", prompt), ("video", video_generate)] + return {"text": prompt, "video": video_generate} def prepare_rotary_positional_embeddings( self, diff --git a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py index c41bbef..ac91bb2 100644 --- a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py +++ b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py @@ -4,88 +4,191 @@ from typing import Any, Tuple import torch -from diffusers import ( - AutoencoderKL, - CogView4Pipeline, - CogView4Transformer2DModel, - FlowMatchEulerDiscreteScheduler, -) from PIL import Image -from transformers import AutoTokenizer, GlmForCausalLM +from transformers import AutoTokenizer, BitsAndBytesConfig, GlmForCausalLM from typing_extensions import override from cogkit.finetune import register from cogkit.finetune.diffusion.schemas import DiffusionComponents from cogkit.finetune.diffusion.trainer import DiffusionTrainer -from cogkit.finetune.utils import unwrap_model +from cogkit.finetune.utils import ( + process_prompt_attention_mask, + unwrap_model, + replace_attn_processor, +) +from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint +from diffusers import ( + AutoencoderKL, + CogView4Pipeline, + CogView4Transformer2DModel, + FlowMatchEulerDiscreteScheduler, +) +from 
diffusers.models.transformers.transformer_cogview4 import CogView4TrainingAttnProcessor class Cogview4Trainer(DiffusionTrainer): UNLOAD_LIST = ["text_encoder", "vae"] MAX_TTOKEN_LENGTH = 224 + NEGATIVE_PROMPT = "" + TEXT_TOKEN_FACTOR = 16 @override def load_components(self) -> DiffusionComponents: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + dtype = self.state.weight_dtype + components = DiffusionComponents() model_path = str(self.args.model_path) + ### pipeline components.pipeline_cls = CogView4Pipeline + ### tokenizer components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + ### text encoder components.text_encoder = GlmForCausalLM.from_pretrained( - model_path, subfolder="text_encoder" + model_path, + subfolder="text_encoder", + torch_dtype=dtype, ) - components.transformer = CogView4Transformer2DModel.from_pretrained( - model_path, subfolder="transformer" + ### transformer + if not self.args.low_vram: + components.transformer = CogView4Transformer2DModel.from_pretrained( + model_path, + subfolder="transformer", + torch_dtype=dtype, + ) + else: + components.transformer = CogView4Transformer2DModel.from_pretrained( + model_path, + subfolder="transformer", + torch_dtype=dtype, + quantization_config=nf4_config, + device=self.accelerator.device, + ) + replace_attn_processor(components.transformer, CogView4TrainingAttnProcessor()) + + ### vae + components.vae = AutoencoderKL.from_pretrained( + model_path, subfolder="vae", torch_dtype=dtype ) - components.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae") - + ### scheduler components.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( model_path, subfolder="scheduler" ) return components @override - def initialize_pipeline(self) -> CogView4Pipeline: - pipe = CogView4Pipeline( - tokenizer=self.components.tokenizer, - text_encoder=self.components.text_encoder, - vae=self.components.vae, - transformer=unwrap_model(self.accelerator, self.components.transformer), - scheduler=self.components.scheduler, - ) + def initialize_pipeline(self, ckpt_path: str | None = None) -> CogView4Pipeline: + if not self.args.low_vram: + pipe = CogView4Pipeline( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=unwrap_model(self.accelerator, self.components.transformer), + scheduler=self.components.scheduler, + ) + else: + assert self.args.training_type == "lora" + # using bf16 model rather than quantized ones + transformer = CogView4Transformer2DModel.from_pretrained( + str(self.args.model_path), + subfolder="transformer", + torch_dtype=self.state.weight_dtype, + ) + replace_attn_processor(transformer, CogView4TrainingAttnProcessor()) + pipe = CogView4Pipeline( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=transformer, + scheduler=self.components.scheduler, + ) + unload_lora_checkpoint(pipe) + load_lora_checkpoint(pipe, ckpt_path) + return pipe @override def encode_text(self, prompt: str) -> torch.Tensor: + """ + Note: For the GLM text encoder, the number of tokens should be a multiple of 16. 
+ """ prompt_token_ids = self.components.tokenizer( prompt, - padding="max_length", + padding=True, max_length=self.MAX_TTOKEN_LENGTH, truncation=True, add_special_tokens=True, return_tensors="pt", + pad_to_multiple_of=self.TEXT_TOKEN_FACTOR, ).input_ids + prompt_embedding = self.components.text_encoder( prompt_token_ids.to(self.accelerator.device), output_hidden_states=True ).hidden_states[-2][0] - # shape of prompt_embedding: [sequence length(self.MAX_TTOKEN_LENGTH), embedding dimension(4096)] + # shape of prompt_embedding: [sequence length, embedding dimension(4096)] return prompt_embedding + @override + def get_negtive_prompt_embeds(self) -> torch.Tensor: + return self.encode_text(self.NEGATIVE_PROMPT) + @override def encode_image(self, image: torch.Tensor) -> torch.Tensor: vae = self.components.vae - image = image.to(vae.device, dtype=vae.dtype) + image = image.to(self.accelerator.device, dtype=vae.dtype) latent_dist = vae.encode(image).latent_dist latent = latent_dist.sample() * vae.config.scaling_factor return latent @override def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]: - ret = {"prompt": [], "prompt_embedding": [], "image": [], "encoded_image": []} + """ + Collate function that processes a batch of samples from the `T2IDatasetWithFactorResize` dataset. + + This function combines individual samples into a batch that can be processed by the model. + It handles prompt embeddings, images, and attention masks, ensuring proper formatting + for model training. + + This function is shared between training and validation dataloaders: + - During training: All fields (prompt, prompt_embedding, image, encoded_image) are provided + - During validation: Only 'prompt' and 'prompt_embedding' are provided, while 'image' and + 'encoded_image' will be None + + Args: + samples: A list of dictionaries, each representing a sample with keys: + - 'prompt': Text prompt string + - 'prompt_embedding': Encoded text prompt tensor + - 'image': Original image tensor (provided only during training) + - 'encoded_image': VAE-encoded latent representation (provided only during training) + + Returns: + A dictionary containing batch-processed data with keys: + - 'prompt': List of prompt strings + - 'prompt_embedding': Tensor of shape [batch_size, sequence_length, embedding_dim] + - 'image': List of image tensors (will be empty during validation) + - 'encoded_image': Tensor of shape [batch_size, channels, height, width] (None during validation) + - 'text_attn_mask': Tensor of shape [batch_size, sequence_length] for transformer attention + + Note: + This function assumes that all images in the batch have the same resolution. 
+ """ + ret = { + "prompt": [], + "prompt_embedding": [], + "image": [], + "encoded_image": [], + "text_attn_mask": None, + } for sample in samples: prompt = sample.get("prompt", None) @@ -101,22 +204,21 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]: if encoded_image is not None: ret["encoded_image"].append(encoded_image) - ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"]) - ret["encoded_image"] = torch.stack(ret["encoded_image"]) if ret["encoded_image"] else None + prompt_embedding, prompt_attention_mask = process_prompt_attention_mask( + self.components.tokenizer, + ret["prompt"], + ret["prompt_embedding"], + self.MAX_TTOKEN_LENGTH, + self.TEXT_TOKEN_FACTOR, + ) - prompts = [sample["prompt"] for sample in samples if "prompt" in sample] - attention_mask = self.components.tokenizer( - prompts, - padding="max_length", - max_length=self.MAX_TTOKEN_LENGTH, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ).attention_mask - ret["attention_mask"] = attention_mask + ret["prompt_embedding"] = prompt_embedding + ret["text_attn_mask"] = prompt_attention_mask + + ret["encoded_image"] = torch.stack(ret["encoded_image"]) if ret["encoded_image"] else None - # shape of prompt_embedding: [batch_size, max_sequence_length(self.MAX_TTOKEN_LENGTH), embedding_dim(4096)] - assert ret["attention_mask"].shape == ret["prompt_embedding"].shape[:2] + # shape of prompt_embedding: [batch_size, sequence_length, embedding_dim(4096)] + assert ret["text_attn_mask"].shape == ret["prompt_embedding"].shape[:2] return ret @@ -132,29 +234,17 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: image_seq_len = ( (image_height // vae_scale_factor) * (image_width // vae_scale_factor) ) // (self.state.transformer_config.patch_size**2) + image_seq_len = torch.tensor([image_seq_len], device=self.accelerator.device) - text_attention_mask = batch["attention_mask"].float() + text_attn_mask = batch["text_attn_mask"] - # prepare timesteps - m = (image_seq_len / self.components.scheduler.config.base_image_seq_len) ** 0.5 - mu = ( - m * self.components.scheduler.config.max_shift - + self.components.scheduler.config.base_shift - ) - self.components.scheduler.set_timesteps( - self.components.scheduler.config.num_train_timesteps, - mu=mu, - device=self.accelerator.device, - ) - timestep = torch.randint( - 0, - self.components.scheduler.config.num_train_timesteps, - (1,), - device=self.accelerator.device, - ).long() + num_train_timesteps = self.components.scheduler.config.num_train_timesteps + sigmas = self.get_sigmas(batch_size, image_seq_len) + timestep = self.get_timestep(batch_size, num_train_timesteps) noise = torch.randn_like(latent) - model_input, model_label = self.add_noise(latent, noise, timestep[0]) + model_input, model_label = self.add_noise(latent, noise, timestep, sigmas) + original_size = torch.tensor( [[image_height, image_width] for _ in range(batch_size)], dtype=latent.dtype, @@ -177,7 +267,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: target_size=target_size, crop_coords=crop_coords, return_dict=False, - attention_mask=text_attention_mask, + attention_kwargs={"text_attn_mask": text_attn_mask}, )[0] loss = torch.mean((noise_pred_cond - model_label) ** 2, dim=(1, 2, 3)) @@ -185,45 +275,73 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: return loss + def get_sigmas(self, batch_size: int, vtoken_seq_len: torch.Tensor) -> torch.Tensor: + assert vtoken_seq_len.ndim == 1 + if vtoken_seq_len.size(0) == 1: + 
vtoken_seq_len = vtoken_seq_len.repeat(batch_size) + else: + assert vtoken_seq_len.size(0) == batch_size + + scheduler = self.components.scheduler + sigmas = torch.linspace( + scheduler.sigma_min, + scheduler.sigma_max, + scheduler.config.num_train_timesteps, + device=self.accelerator.device, + ) + m = (vtoken_seq_len / scheduler.config.base_image_seq_len) ** 0.5 + mu = m * scheduler.config.max_shift + scheduler.config.base_shift + mu = mu.unsqueeze(1) + sigmas = mu / (mu + (1 / sigmas - 1)) + sigmas = torch.cat([torch.zeros((batch_size, 1), device=sigmas.device), sigmas], dim=1) + return sigmas + + def get_timestep(self, batch_size: int, num_train_timesteps: int) -> torch.LongTensor: + return torch.randint( + 0, + num_train_timesteps, + (batch_size,), + device=self.accelerator.device, + ) + def add_noise( - self, latent: torch.Tensor, noise: torch.Tensor, timestep: torch.LongTensor + self, + latent: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + sigmas: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Add noise to the latent vector based on the timestep. - - Args: - latent (torch.Tensor): The latent vector to add noise to. - noise (torch.Tensor): The noise tensor to add. - timestep (torch.LongTensor): The current timestep. + assert latent.shape[0] == noise.shape[0] == timestep.shape[0] == sigmas.shape[0] + index = timestep + scale_factor = ( + torch.gather(sigmas, dim=1, index=index.unsqueeze(1)) + .squeeze(1) + .view(-1, 1, 1, 1) + .to(latent.device) + ) - Returns: - Tuple[torch.Tensor, torch.Tensor]: The noisy latent vector that will be input to the model and the model label. - """ - num_train_timesteps = self.components.scheduler.config.num_train_timesteps - # note: sigmas in scheduler is arranged in reversed order - scale_factor = self.components.scheduler.sigmas[num_train_timesteps - timestep] model_input = latent * (1 - scale_factor) + noise * scale_factor model_label = noise - latent return model_input, model_label @override def validation_step( - self, eval_data: dict[str, Any], pipe: CogView4Pipeline - ) -> list[tuple[str, Image.Image | list[Image.Image]]]: - """ - Return the data that needs to be saved.
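The resolution-dependent sigma shift and the velocity target used here can be summarized in a standalone sketch; the constants (base_seq_len, base_shift, max_shift) and the sigma range below are assumed placeholders, not values read from the CogView4 scheduler config.

import torch

def shifted_sigmas(seq_len: torch.Tensor, num_steps: int = 1000,
                   base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75):
    # Longer image token sequences push the schedule toward larger sigmas.
    sigmas = torch.linspace(1e-4, 1.0, num_steps)
    mu = (seq_len.float() / base_seq_len) ** 0.5 * max_shift + base_shift
    return mu[:, None] / (mu[:, None] + (1.0 / sigmas - 1.0))  # [batch, num_steps]

def interpolate(latent, noise, sigma):
    # Flow-matching forward process: blend latent and noise, predict the velocity.
    sigma = sigma.view(-1, 1, 1, 1)
    return latent * (1 - sigma) + noise * sigma, noise - latent

latent = torch.randn(2, 16, 32, 32)
noise = torch.randn_like(latent)
sigmas = shifted_sigmas(torch.tensor([1024, 1024]))
t = torch.randint(0, sigmas.shape[1], (2,))
model_input, model_label = interpolate(latent, noise, sigmas.gather(1, t[:, None]).squeeze(1))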
For images, the data format is PIL - """ + self, pipe: CogView4Pipeline, eval_data: dict[str, Any] + ) -> dict[str, str | Image.Image]: prompt = eval_data["prompt"] - _ = eval_data["prompt_embedding"] + prompt_embedding = eval_data["prompt_embedding"] image_generate = pipe( height=self.state.train_resolution[0], width=self.state.train_resolution[1], - prompt=prompt, - # prompt_embeds=prompt_embedding, + prompt_embeds=prompt_embedding, + negative_prompt_embeds=self.state.negative_prompt_embeds.unsqueeze( + 0 + ), # Add batch dimension generator=self.state.generator, ).images[0] - return [("text", prompt), ("image", image_generate)] + return {"text": prompt, "image": image_generate} register("cogview4-6b", "lora", Cogview4Trainer) diff --git a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py index 05a4c88..76f1aef 100644 --- a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py +++ b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py @@ -7,35 +7,53 @@ from typing_extensions import override from cogkit.finetune import register -from cogkit.finetune.diffusion.trainer import DiffusionState from cogkit.finetune.utils import ( - expand_list, process_latent_attention_mask, process_prompt_attention_mask, ) +from cogkit.utils import expand_list +from diffusers.models.transformers.transformer_cogview4 import CogView4RotaryPosEmbed from .lora_trainer import Cogview4Trainer class Cogview4LoraPackingTrainer(Cogview4Trainer): + IMAGE_FACTOR = 32 # Size of image (height, width) to be trained should be a multiple of 32 + DOWNSAMPLER_FACTOR = 8 + PATCH_SIZE: int + ATTN_HEAD: int + ATTEN_DIM: int + ROPE_DIM: Tuple[int, int] + + max_vtoken_length: int + training_seq_length: int + rope: CogView4RotaryPosEmbed @override - def _init_state(self) -> DiffusionState: - patch_size = self.components.transformer.config.patch_size + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + transformer = self.components.transformer + self.PATCH_SIZE = transformer.config.patch_size + self.ATTN_HEAD = transformer.config.num_attention_heads + self.ATTEN_DIM = transformer.config.attention_head_dim + self.ROPE_DIM = transformer.config.rope_axes_dim + + patch_size = self.PATCH_SIZE height, width = self.args.train_resolution sample_height, sample_width = ( height // self.DOWNSAMPLER_FACTOR, width // self.DOWNSAMPLER_FACTOR, ) - max_vtoken_length = sample_height * sample_width // patch_size**2 - training_seq_length = max_vtoken_length + self.MAX_TTOKEN_LENGTH - - return DiffusionState( - weight_dtype=self.get_training_dtype(), - train_resolution=self.args.train_resolution, - max_vtoken_length=max_vtoken_length, - training_seq_length=training_seq_length, + self.max_vtoken_length = sample_height * sample_width // patch_size**2 + self.training_seq_length = self.max_vtoken_length + self.MAX_TTOKEN_LENGTH + self.state.training_seq_length = self.training_seq_length + + self.rope = CogView4RotaryPosEmbed( + dim=self.ATTEN_DIM, + patch_size=self.PATCH_SIZE, + rope_axes_dim=self.ROPE_DIM, ) @override @@ -47,185 +65,147 @@ def sample_to_length(self, sample: dict[str, Any]) -> int: assert prompt_embedding.ndim == 2 num_channels, latent_height, latent_width = image_latent.shape - patch_size = self.components.transformer.config.patch_size + patch_size = self.PATCH_SIZE paded_width = latent_width + latent_width % patch_size paded_height = latent_height + 
latent_height % patch_size latent_length = paded_height * paded_width // patch_size**2 - if latent_length > self.state.max_vtoken_length: + if latent_length > self.max_vtoken_length: raise ValueError( - f"latent_length {latent_length} is greater than max_vtoken_length {self.state.max_vtoken_length}, " - f"which means there is at least one sample in the batch has resolution greater than " - f"{self.args.train_resolution[0]}x{self.args.train_resolution[1]}" + f"latent_length {latent_length} is greater than max_vtoken_length {self.max_vtoken_length}" ) - assert ( - self.MAX_TTOKEN_LENGTH + self.state.max_vtoken_length == self.state.training_seq_length - ) - assert latent_length + prompt_embedding.shape[0] <= self.state.training_seq_length + assert self.MAX_TTOKEN_LENGTH + self.max_vtoken_length == self.training_seq_length + assert latent_length + prompt_embedding.shape[0] <= self.training_seq_length return latent_length + prompt_embedding.shape[0] @override def collate_fn_packing(self, samples: list[dict[str, list[Any]]]) -> dict[str, Any]: - """ - Note: This collate_fn is for the training dataloader. - For validation, you should use the collate_fn from Cogview4Trainer. + """Collate function for training dataloader with packing support. + + This function processes batches of samples from the `T2IDatasetWithPacking` dataset, + combining multiple samples into a single batch while maintaining proper attention masks + and positional embeddings. + + Args: + samples: List of dictionaries containing packed samples. Each sample contains: + - prompt: List of text prompts + - prompt_embedding: List of prompt embeddings + - encoded_image: List of encoded image latents + - image: List of original images + + Returns: + dict: A dictionary containing batched data with the following keys: + - prompt_embedding: Batched prompt embeddings + - encoded_image: Batched encoded image latents + - image_rotary_emb: Rotary embeddings for images + - attention_kwargs: Dictionary containing: + - batch_flag: Indices indicating which sample each item belongs to + - text_attn_mask: Attention mask for text embeddings + - latent_attn_mask: Attention mask for latent embeddings + - pixel_mask: Mask for valid pixel regions + - original_size: Original dimensions of the images + + Note: + This function is specifically used for the training dataloader. + For validation, the collate_fn from Cogview4Trainer should be used instead. 
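A toy example of the batch_flag bookkeeping described above; the sample dictionaries are made up, and the inline loop mimics what expand_list does with a list of packed groups.

import torch

samples = [
    {"prompt": ["a cat", "a dog"]},  # packed group 0 holds two samples
    {"prompt": ["a house"]},         # packed group 1 holds one sample
]
batch_flag = sum([[idx] * len(s["prompt"]) for idx, s in enumerate(samples)], [])
batch_flag = torch.tensor(batch_flag, dtype=torch.int32)  # tensor([0, 0, 1])

# expand_list-style flattening: list of dicts -> dict of lists
flat = {}
for s in samples:
    for key, values in s.items():
        flat.setdefault(key, []).extend(values)
assert flat == {"prompt": ["a cat", "a dog", "a house"]}
assert len(batch_flag) == len(flat["prompt"])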
""" batched_data = { "prompt_embedding": None, "encoded_image": None, - "attention_mask": { + "image_rotary_emb": None, + "attention_kwargs": { "batch_flag": None, - "text_embedding_attn_mask": None, - "latent_embedding_attn_mask": None, + "text_attn_mask": None, + "latent_attn_mask": None, }, "pixel_mask": None, + "original_size": None, } + batch_flag = [[idx] * len(slist["prompt"]) for idx, slist in enumerate(samples)] + batch_flag = sum(batch_flag, []) + batch_flag = torch.tensor(batch_flag, dtype=torch.int32) samples = expand_list(samples) + assert len(batch_flag) == len(samples["prompt"]) prompt_embedding, prompt_attention_mask = process_prompt_attention_mask( self.components.tokenizer, samples["prompt"], samples["prompt_embedding"], self.MAX_TTOKEN_LENGTH, + self.TEXT_TOKEN_FACTOR, ) - patch_size = self.components.transformer.config.patch_size + patch_size = self.PATCH_SIZE + image_rotary_emb = [self.rope(ei.unsqueeze(0)) for ei in samples["encoded_image"]] padded_latent, vtoken_attention_mask, pixel_mask = process_latent_attention_mask( samples["encoded_image"], patch_size ) # Store in batched_data batched_data["prompt_embedding"] = prompt_embedding - batched_data["attention_mask"]["text_embedding_attn_mask"] = prompt_attention_mask + batched_data["attention_kwargs"]["text_attn_mask"] = prompt_attention_mask batched_data["encoded_image"] = padded_latent - batched_data["attention_mask"]["latent_embedding_attn_mask"] = vtoken_attention_mask + batched_data["image_rotary_emb"] = image_rotary_emb + batched_data["attention_kwargs"]["latent_attn_mask"] = vtoken_attention_mask.reshape( + len(batch_flag), -1 + ) batched_data["pixel_mask"] = pixel_mask + batched_data["attention_kwargs"]["batch_flag"] = batch_flag + batched_data["original_size"] = torch.tensor( + [(img.height, img.width) for img in samples["image"]] + ) + return batched_data @override def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: - patch_size = self.components.transformer.config.patch_size + dtype = self.get_training_dtype() prompt_embeds = batch["prompt_embedding"] latent = batch["encoded_image"] + image_rotary_emb = batch["image_rotary_emb"] batch_size, text_seqlen, text_embedding_dim = prompt_embeds.shape batch_size, num_channels, height, width = latent.shape - attn_mask = batch["attention_mask"] - latent_attention_mask = attn_mask["latent_embedding_attn_mask"].float() - latent_attention_mask_1d = latent_attention_mask.reshape(batch_size, -1) - vtoken_seq_len = torch.sum(latent_attention_mask_1d != -1, dim=1) - latent_shape = [] - - for i, data in enumerate(latent_attention_mask): - row_indices = torch.where(data[:, 0] == -1)[0] - if len(row_indices) > 0: - num_rows = row_indices[0].item() - else: - num_rows = data.shape[0] - - col_indices = torch.where(data[0, :] == -1)[0] - if len(col_indices) > 0: - num_cols = col_indices[0].item() - else: - num_cols = data.shape[1] - latent_shape.append((num_rows, num_cols)) - latent_shape = torch.tensor(latent_shape) - original_shape = latent_shape * self.DOWNSAMPLER_FACTOR * patch_size - assert torch.equal( - vtoken_seq_len.cpu(), - torch.prod((original_shape / (self.DOWNSAMPLER_FACTOR * patch_size)), dim=1), - ) + attention_kwargs = batch["attention_kwargs"] + latent_attention_mask = attention_kwargs["latent_attn_mask"].float() + assert latent_attention_mask.dim() == 2 + vtoken_seq_len = torch.sum(latent_attention_mask != 0, dim=1) - # prepare sigmas - scheduler = self.components.scheduler - sigmas = torch.linspace( - scheduler.sigma_min, - scheduler.sigma_max, - 
scheduler.config.num_train_timesteps, - device=self.accelerator.device, - ) + original_size = batch["original_size"] - m = (vtoken_seq_len / scheduler.config.base_image_seq_len) ** 0.5 - mu = m * scheduler.config.max_shift + scheduler.config.base_shift - mu = mu.unsqueeze(1) - sigmas = mu / (mu + (1 / sigmas - 1)) - sigmas = torch.flip(sigmas, dims=[1]) - sigmas = torch.cat([sigmas, torch.zeros((batch_size, 1), device=sigmas.device)], dim=1) - self.components.scheduler.sigmas = sigmas - - timestep = torch.randint( - 0, - scheduler.config.num_train_timesteps, - (batch_size,), - device=self.accelerator.device, - ) + num_train_timesteps = self.components.scheduler.config.num_train_timesteps + sigmas = self.get_sigmas(batch_size, vtoken_seq_len) + timestep = self.get_timestep(batch_size, num_train_timesteps) - noise = torch.randn_like(latent) - model_input, model_label = self.add_noise(latent, noise, timestep) - original_size = torch.tensor( - original_shape, - dtype=latent.dtype, - device=self.accelerator.device, - ) - target_size = torch.tensor( - original_shape, - dtype=latent.dtype, - device=self.accelerator.device, - ) + noise = torch.randn_like(latent, dtype=dtype) + model_input, model_label = self.add_noise(latent, noise, timestep, sigmas) + + original_size = original_size.to(dtype=dtype, device=self.accelerator.device) + target_size = original_size.clone().to(dtype=dtype, device=self.accelerator.device) crop_coords = torch.tensor( - [[0, 0] for _ in range(batch_size)], dtype=latent.dtype, device=self.accelerator.device + [[0, 0] for _ in range(batch_size)], dtype=dtype, device=self.accelerator.device ) - # FIXME: add attn support for cogview4 transformer noise_pred_cond = self.components.transformer( - hidden_states=model_input, - encoder_hidden_states=prompt_embeds, + hidden_states=model_input.to(dtype=dtype), + encoder_hidden_states=prompt_embeds.to(dtype=dtype), timestep=timestep, original_size=original_size, target_size=target_size, crop_coords=crop_coords, return_dict=False, - attention_mask=attn_mask, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, )[0] pixel_mask = batch["pixel_mask"] - pixel_mask[pixel_mask == 0] = 1 - pixel_mask[pixel_mask == -1] = 0 - loss = torch.mean(((noise_pred_cond - model_label) ** 2) * pixel_mask, dim=(1, 2, 3)) + loss = torch.sum(((noise_pred_cond - model_label) ** 2) * pixel_mask, dim=(1, 2, 3)) + loss = loss / torch.sum(pixel_mask, dim=(1, 2, 3)) loss = loss.mean() return loss - @override - def add_noise( - self, latent: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Add noise to the latent vector based on the timestep. - - Args: - latent (torch.Tensor): The latent vector to add noise to. - noise (torch.Tensor): The noise tensor to add. - timestep (torch.Tensor): The current timestep. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The noisy latent vector that will be input to the model and the model label. 
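The loss normalization used above (masked squared error summed and divided by the number of valid pixels, so packing padding never dilutes a sample's loss) can be seen in isolation in this sketch with made-up tensor sizes.

import torch

def masked_mse(pred: torch.Tensor, target: torch.Tensor, pixel_mask: torch.Tensor) -> torch.Tensor:
    # pixel_mask is 1 on real latent pixels and 0 on packing padding.
    se = (pred - target) ** 2 * pixel_mask
    per_sample = se.sum(dim=(1, 2, 3)) / pixel_mask.sum(dim=(1, 2, 3))
    return per_sample.mean()

pred = torch.randn(2, 16, 8, 8)
target = torch.randn(2, 16, 8, 8)
mask = torch.ones_like(pred)
mask[1, :, 4:, :] = 0  # the second sample only fills the top half of the canvas
loss = masked_mse(pred, target, mask)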
- """ - num_train_timesteps = self.components.scheduler.config.num_train_timesteps - # note: sigmas in scheduler is arranged in reversed order - index = num_train_timesteps - timestep - scale_factor = ( - torch.gather(self.components.scheduler.sigmas, dim=1, index=index.unsqueeze(1)) - .squeeze(1) - .view(-1, 1, 1, 1) - .to(latent.device) - ) - - model_input = latent * (1 - scale_factor) + noise * scale_factor - model_label = noise - latent - return model_input, model_label - register("cogview4-6b", "lora-packing", Cogview4LoraPackingTrainer) diff --git a/src/cogkit/finetune/diffusion/models/cogview/cogview4/sft_trainer_packing.py b/src/cogkit/finetune/diffusion/models/cogview/cogview4/sft_trainer_packing.py new file mode 100644 index 0000000..e278cc0 --- /dev/null +++ b/src/cogkit/finetune/diffusion/models/cogview/cogview4/sft_trainer_packing.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + + +from cogkit.finetune import register + +from .lora_trainer_packing import Cogview4LoraPackingTrainer + + +class Cogview4SFTPackingTrainer(Cogview4LoraPackingTrainer): + pass + + +register("cogview4-6b", "sft-packing", Cogview4SFTPackingTrainer) diff --git a/src/cogkit/finetune/diffusion/schemas/args.py b/src/cogkit/finetune/diffusion/schemas/args.py index 92ac87b..8dad3d1 100644 --- a/src/cogkit/finetune/diffusion/schemas/args.py +++ b/src/cogkit/finetune/diffusion/schemas/args.py @@ -68,7 +68,7 @@ def parse_args(cls): parser.add_argument("--enable_tiling", action="store_true") # Packing - parser.add_argument("--enable_packing", action="store_true") + parser.add_argument("--enable_packing", type=lambda x: x.lower() == "true", default=False) # Validation parser.add_argument("--gen_fps", type=int, default=15) diff --git a/src/cogkit/finetune/diffusion/schemas/state.py b/src/cogkit/finetune/diffusion/schemas/state.py index 8823b7b..fc7bcbe 100644 --- a/src/cogkit/finetune/diffusion/schemas/state.py +++ b/src/cogkit/finetune/diffusion/schemas/state.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Any +import torch + from cogkit.finetune.base import BaseState @@ -15,9 +17,10 @@ class DiffusionState(BaseState): train_resolution: tuple[int, int, int] | tuple[int, int] # packing realted - max_vtoken_length: int | None = None training_seq_length: int | None = None validation_prompts: list[str] = [] validation_images: list[Path | None] = [] validation_videos: list[Path | None] = [] + + negative_prompt_embeds: torch.Tensor | None = None diff --git a/src/cogkit/finetune/diffusion/trainer.py b/src/cogkit/finetune/diffusion/trainer.py index edb914e..679d1ae 100644 --- a/src/cogkit/finetune/diffusion/trainer.py +++ b/src/cogkit/finetune/diffusion/trainer.py @@ -3,19 +3,18 @@ import torch import wandb -from accelerate.utils import ( - gather_object, -) +from accelerate import cpu_offload +from accelerate.utils import gather_object from PIL import Image from typing_extensions import override from cogkit.finetune.base import BaseTrainer from cogkit.samplers import NaivePackingSampler +from cogkit.utils import expand_list from diffusers.pipelines import DiffusionPipeline from diffusers.utils.export_utils import export_to_video from ..utils import ( - cast_training_params, free_memory, get_memory_statistics, unload_model, @@ -53,38 +52,36 @@ def prepare_models(self) -> None: @override def prepare_dataset(self) -> None: - # TODO: refactor later - match self.args.model_type: - case "i2v": - from cogkit.datasets import BaseI2VDataset, I2VDatasetWithResize - - dataset_cls = I2VDatasetWithResize - if 
self.args.enable_packing: - dataset_cls = BaseI2VDataset - raise NotImplementedError("Packing for I2V is not implemented") - - case "t2v": - from cogkit.datasets import BaseT2VDataset, T2VDatasetWithResize - - dataset_cls = T2VDatasetWithResize - if self.args.enable_packing: - dataset_cls = BaseT2VDataset - raise NotImplementedError("Packing for T2V is not implemented") - - case "t2i": - from cogkit.datasets import ( - BaseT2IDataset, - T2IDatasetWithResize, - T2IDatasetWithPacking, - ) - - dataset_cls = T2IDatasetWithResize - if self.args.enable_packing: - dataset_cls = BaseT2IDataset - dataset_cls_packing = T2IDatasetWithPacking - - case _: - raise ValueError(f"Invalid model type: {self.args.model_type}") + if self.args.model_type == "i2v": + from cogkit.datasets import BaseI2VDataset, I2VDatasetWithResize + + dataset_cls = I2VDatasetWithResize + if self.args.enable_packing: + dataset_cls = BaseI2VDataset + raise NotImplementedError("Packing for I2V is not implemented") + + elif self.args.model_type == "t2v": + from cogkit.datasets import BaseT2VDataset, T2VDatasetWithResize + + dataset_cls = T2VDatasetWithResize + if self.args.enable_packing: + dataset_cls = BaseT2VDataset + raise NotImplementedError("Packing for T2V is not implemented") + + elif self.args.model_type == "t2i": + from cogkit.datasets import ( + T2IDatasetWithFactorResize, + T2IDatasetWithPacking, + T2IDatasetWithResize, + ) + + dataset_cls = T2IDatasetWithResize + if self.args.enable_packing: + dataset_cls = T2IDatasetWithFactorResize + dataset_cls_packing = T2IDatasetWithPacking + + else: + raise ValueError(f"Invalid model type: {self.args.model_type}") additional_args = { "device": self.accelerator.device, @@ -103,18 +100,19 @@ def prepare_dataset(self) -> None: using_train=False, ) - # Prepare VAE and text encoder for encoding + ### Prepare VAE and text encoder for encoding self.components.vae.requires_grad_(False) self.components.text_encoder.requires_grad_(False) - self.components.vae = self.components.vae.to( - self.accelerator.device, dtype=self.state.weight_dtype - ) - self.components.text_encoder = self.components.text_encoder.to( - self.accelerator.device, dtype=self.state.weight_dtype - ) + self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype) + if self.args.low_vram: # offload text encoder to CPU + cpu_offload(self.components.text_encoder, self.accelerator.device) + else: + self.components.text_encoder.to(self.accelerator.device, dtype=self.state.weight_dtype) + + ### Precompute embedding + self.logger.info("Precomputing embedding ...") + self.state.negative_prompt_embeds = self.get_negtive_prompt_embeds() - # Precompute latent for video and prompt embedding - self.logger.info("Precomputing latent for video and prompt embedding ...") for dataset in [self.train_dataset, self.test_dataset]: if dataset is None: continue @@ -130,10 +128,11 @@ def prepare_dataset(self) -> None: ... self.accelerator.wait_for_everyone() - self.logger.info("Precomputing latent for video and prompt embedding ... Done") + self.logger.info("Precomputing embedding ... 
Done") unload_model(self.components.vae) - unload_model(self.components.text_encoder) + if not self.args.low_vram: + unload_model(self.components.text_encoder) free_memory() if not self.args.enable_packing: @@ -172,13 +171,10 @@ def prepare_dataset(self) -> None: ) @override - def validate(self, step: int) -> None: - # TODO: refactor later + def validate(self, step: int, ckpt_path: str | None = None) -> None: self.logger.info("Starting validation") - accelerator = self.accelerator num_validation_samples = len(self.test_data_loader) - if num_validation_samples == 0: self.logger.warning("No validation samples found. Skipping validation.") return @@ -192,18 +188,23 @@ def validate(self, step: int) -> None: ) ##### Initialize pipeline ##### - pipe = self.initialize_pipeline() + pipe = self.initialize_pipeline(ckpt_path=ckpt_path) if self.state.using_deepspeed: # Can't using model_cpu_offload in deepspeed, # so we need to move all components in pipe to device self.move_components_to_device( - dtype=self.state.weight_dtype, ignore_list=["transformer"] + dtype=self.state.weight_dtype, + device=self.accelerator.device, + ignore_list=["transformer"], ) else: # if not using deepspeed, use model_cpu_offload to further reduce memory usage # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage - pipe.enable_model_cpu_offload(device=self.accelerator.device) + if self.args.low_vram: + pipe.enable_sequential_cpu_offload(device=self.accelerator.device) + else: + pipe.enable_model_cpu_offload(device=self.accelerator.device) # Convert all model weights to training dtype # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32 @@ -214,24 +215,25 @@ def validate(self, step: int) -> None: all_processes_artifacts = [] for i, batch in enumerate(self.test_data_loader): # only batch size = 1 is currently supported - prompt = batch.get("prompt", []) - prompt = prompt[0] if prompt else prompt + prompt = batch.get("prompt", None) + prompt = prompt[0] if prompt else None prompt_embedding = batch.get("prompt_embedding", None) - image = batch.get("image", []) - image = image[0] if image else image + image = batch.get("image", None) + image = image[0] if image else None encoded_image = batch.get("encoded_image", None) - video = batch.get("video", []) - video = video[0] if video else video + video = batch.get("video", None) + video = video[0] if video else None encoded_video = batch.get("encoded_video", None) self.logger.debug( - f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}", + f"Validating sample {i + 1}/{num_validation_samples} on process {self.accelerator.process_index}. 
Prompt: {prompt}", main_process_only=False, ) - validation_artifacts = self.validation_step( - { + val_res = self.validation_step( + pipe=pipe, + eval_data={ "prompt": prompt, "prompt_embedding": prompt_embedding, "image": image, @@ -239,108 +241,65 @@ def validate(self, step: int) -> None: "video": video, "encoded_video": encoded_video, }, - pipe, ) artifacts = {} - for ii, (artifact_type, artifact_value) in enumerate(validation_artifacts): - artifacts.update( - { - f"artifact_{ii}": { - "type": artifact_type, - "value": artifact_value, - } - } - ) - self.logger.debug( - f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}", - main_process_only=False, - ) - - for key, value in list(artifacts.items()): - artifact_type = value["type"] - artifact_value = value["value"] - if artifact_type not in ["text", "image", "video"] or artifact_value is None: - continue - - match artifact_type: - case "text": - extension = "txt" - case "image": - extension = "png" - case "video": - extension = "mp4" - validation_path = self.args.output_dir / "validation_res" / f"validation-{step}" - validation_path.mkdir(parents=True, exist_ok=True) - filename = f"artifact-process{accelerator.process_index}-batch{i}.{extension}" - filename = str(validation_path / filename) - - if artifact_type == "image": - self.logger.debug(f"Saving image to {filename}") - artifact_value.save(filename) - artifact_value = wandb.Image(filename) - elif artifact_type == "video": - self.logger.debug(f"Saving video to {filename}") - export_to_video(artifact_value, filename, fps=self.args.gen_fps) - artifact_value = wandb.Video(filename) - elif artifact_type == "text": - self.logger.debug(f"Saving text to {filename}") - with open(filename, "w") as f: - f.write(artifact_value) - artifact_value = str(artifact_value) - - all_processes_artifacts.append(artifact_value) + val_path = self.args.output_dir / "validation_res" / f"validation-{step}" + val_path.mkdir(parents=True, exist_ok=True) + filename = f"artifact-process{self.accelerator.process_index}-batch{i}" + + image = val_res.get("image", None) + video = val_res.get("video", None) + with open(val_path / f"{filename}.txt", "w") as f: + f.write(prompt) + if image: + fpath = str(val_path / f"{filename}.png") + image.save(fpath) + artifacts["image"] = wandb.Image(fpath, caption=prompt) + if video: + fpath = str(val_path / f"{filename}.mp4") + export_to_video(video, fpath, fps=self.args.gen_fps) + artifacts["video"] = wandb.Video(fpath, caption=prompt) + + all_processes_artifacts.append(artifacts) all_artifacts = gather_object(all_processes_artifacts) + all_artifacts = expand_list(all_artifacts) - if accelerator.is_main_process: + if self.accelerator.is_main_process: tracker_key = "validation" - for tracker in accelerator.trackers: + for tracker in self.accelerator.trackers: if tracker.name == "wandb": - text_artifacts = [ - artifact for artifact in all_artifacts if isinstance(artifact, str) - ] - image_artifacts = [ - artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image) - ] - video_artifacts = [ - artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video) - ] - tracker.log( - { - tracker_key: { - "texts": text_artifacts, - "images": image_artifacts, - "videos": video_artifacts, - }, - }, - step=step, - ) + tracker.log({tracker_key: all_artifacts}, step=step) ########## Clean up ########## if self.state.using_deepspeed: del pipe # Unload models except those needed for training - 
self.move_components_to_cpu(unload_list=self.UNLOAD_LIST) + self.move_components_to_device( + dtype=self.state.weight_dtype, device="cpu", ignore_list=["transformer"] + ) else: pipe.remove_all_hooks() del pipe # Load models except those not needed for training self.move_components_to_device( - dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST + dtype=self.state.weight_dtype, + device=self.accelerator.device, + ignore_list=self.UNLOAD_LIST, ) self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype) # Change trainable weights back to fp32 to keep with dtype after prepare the model - cast_training_params([self.components.transformer], dtype=torch.float32) + # cast_training_params([self.components.transformer], dtype=torch.float32) free_memory() - accelerator.wait_for_everyone() + self.accelerator.wait_for_everyone() ################################ memory_statistics = get_memory_statistics(self.logger) self.logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(accelerator.device) + torch.cuda.reset_peak_memory_stats(self.accelerator.device) torch.set_grad_enabled(True) self.components.transformer.train() @@ -354,13 +313,20 @@ def compute_loss(self, batch) -> torch.Tensor: raise NotImplementedError def collate_fn(self, samples: list[dict[str, Any]]): + """ + Note: This collate_fn function are used for both training and validation. + """ raise NotImplementedError - def initialize_pipeline(self) -> DiffusionPipeline: + def initialize_pipeline(self, ckpt_path: str | None = None) -> DiffusionPipeline: raise NotImplementedError def encode_text(self, text: str) -> torch.Tensor: - # shape of output text: [batch size, sequence length, embedding dimension] + # shape of output text: [sequence length, embedding dimension] + raise NotImplementedError + + def get_negtive_prompt_embeds(self) -> torch.Tensor: + # shape of output text: [sequence length, embedding dimension] raise NotImplementedError def encode_image(self, image: torch.Tensor) -> torch.Tensor: @@ -374,8 +340,27 @@ def encode_video(self, video: torch.Tensor) -> torch.Tensor: raise NotImplementedError def validation_step( - self, - ) -> list[tuple[str, Image.Image | list[Image.Image]]]: + self, pipe: DiffusionPipeline, eval_data: dict[str, Any] + ) -> dict[str, str | Image.Image | list[Image.Image]]: + """ + Perform a validation step using the provided pipeline and evaluation data. + + Args: + pipe: The diffusion pipeline instance used for validation. + eval_data: A dictionary containing data for validation, may include: + - "prompt": Text prompt for generation (str). + - "prompt_embedding": Pre-computed text embeddings. + - "image": Input image for image-to-image tasks. + - "encoded_image": Pre-computed image embeddings. + - "video": Input video for video tasks. + - "encoded_video": Pre-computed video embeddings. + + Returns: + A dictionary containing generated artifacts with keys: + - "text": Text data (str). + - "image": Generated image (PIL.Image.Image). + - "video": Generated video (list[PIL.Image.Image]). 
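A minimal subclass override satisfying this contract might look like the sketch below; the pipeline call and its keyword arguments are assumptions about a typical diffusers pipeline, not cogkit API.

from typing import Any
from PIL import Image

class MyTrainer:  # stand-in for a DiffusionTrainer subclass
    def validation_step(self, pipe, eval_data: dict[str, Any]) -> dict[str, Any]:
        prompt = eval_data["prompt"]
        image: Image.Image = pipe(prompt=prompt).images[0]
        # A video task would instead return {"text": prompt, "video": [frame, ...]}.
        return {"text": prompt, "image": image}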
+ """ raise NotImplementedError # ========== Packing related functions ========== diff --git a/src/cogkit/finetune/utils/__init__.py b/src/cogkit/finetune/utils/__init__.py index b7a1e50..8eaeafc 100644 --- a/src/cogkit/finetune/utils/__init__.py +++ b/src/cogkit/finetune/utils/__init__.py @@ -3,6 +3,5 @@ from .memory_utils import * # noqa from .optimizer_utils import * # noqa from .torch_utils import * # noqa -from .misc import * # noqa from .filters import * # noqa from .attn_mask import * # noqa diff --git a/src/cogkit/finetune/utils/attn_mask.py b/src/cogkit/finetune/utils/attn_mask.py index c2d6ed1..5c34935 100644 --- a/src/cogkit/finetune/utils/attn_mask.py +++ b/src/cogkit/finetune/utils/attn_mask.py @@ -1,16 +1,24 @@ -from typing import List, Tuple +import math +from typing import Any, List, Tuple import torch from transformers import AutoTokenizer +from diffusers.models.attention_processor import Attention from .filters import MeanFilter +def mask_assert(mask: torch.Tensor) -> None: + assert torch.all((mask == 0) | (mask == 1)), "mask contains values other than 0 or 1" + assert mask.dtype == torch.int32, "mask dtype should be torch.int32" + + def process_prompt_attention_mask( tokenizer: AutoTokenizer, prompt: List[str], prompt_embedding: List[torch.Tensor], max_ttoken_length: int, + pad_to_multiple_of: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Process prompt attention mask for training. @@ -25,25 +33,22 @@ def process_prompt_attention_mask( A tuple of (prompt_embedding, prompt_attention_mask) Attention mask values: - -1: Padding tokens (tokens to be ignored) - 0: Tokens belonging to the first micro-batch (when enable packing) + 0: Padding tokens (tokens to be ignored) + 1: True tokens (tokens to be attended) """ # Tokenize the prompt in the batch-level tokenized_prompt = tokenizer( prompt, - padding=True, + padding="longest", max_length=max_ttoken_length, truncation=True, add_special_tokens=True, + pad_to_multiple_of=pad_to_multiple_of, return_tensors="pt", ) prompt_attention_mask = tokenized_prompt.attention_mask num_samples = len(prompt) - ### Retrieve the unpadded value from prompt_embedding, then pad it to the max length in the batch - for idx, embedding in enumerate(prompt_embedding): - prompt_embedding[idx] = embedding[torch.where(prompt_attention_mask[idx] == 1)[0]] - token_length_list = [embedding.shape[0] for embedding in prompt_embedding] max_seqlen = max(token_length_list) assert max_seqlen == prompt_attention_mask.shape[1] @@ -55,11 +60,8 @@ def process_prompt_attention_mask( assert prompt_embedding.shape[0] == prompt_attention_mask.shape[0] == num_samples assert prompt_embedding.shape[1] == prompt_attention_mask.shape[1] == max_seqlen - ### Construct the prompt attention mask - prompt_attention_mask[ - prompt_attention_mask == 0 - ] = -1 # -1 means padding token type (tokens to be ignored) - prompt_attention_mask[prompt_attention_mask == 1] = 0 # 0 means tokens belong to micro-batch0 + prompt_attention_mask = prompt_attention_mask.to(torch.int32) + mask_assert(prompt_attention_mask) return prompt_embedding, prompt_attention_mask @@ -76,11 +78,11 @@ def process_latent_attention_mask( patch_size: Patch size for the transformer Returns: - A tuple of (padded_latent, vtoken_attention_mask) + A tuple of (padded_latent, vtoken_attention_mask, pixel_mask) Attention mask values: - -1: Padding tokens (tokens to be ignored) - 0: Tokens belonging to the first micro-batch (when enable packing) + 0: Padding tokens (tokens to be ignored) + 1: True tokens (tokens to be 
attended) """ num_samples = len(encoded_images) num_latent_channel = encoded_images[0].shape[0] @@ -90,8 +92,8 @@ def process_latent_attention_mask( max_latent_width = max([img.shape[2] for img in encoded_images]) # Ensure dimensions are divisible by patch_size - max_latent_height += max_latent_height % patch_size - max_latent_width += max_latent_width % patch_size + max_latent_height = math.ceil(max_latent_height / patch_size) * patch_size + max_latent_width = math.ceil(max_latent_width / patch_size) * patch_size # Create padded latent tensor and pixel mask padded_latent = torch.zeros( @@ -101,22 +103,12 @@ def process_latent_attention_mask( max_latent_width, dtype=torch.float32, ) - pixel_mask = ( - torch.ones( - num_samples, - num_latent_channel, - max_latent_height, - max_latent_width, - dtype=torch.float32, - ) - * -1 - ) + pixel_mask = padded_latent.clone() # Fill padded latent and set mask values for idx, latent in enumerate(encoded_images): padded_latent[idx, :, : latent.shape[1], : latent.shape[2]] = latent - # 0 means this pixel belongs to micro-batch0 - pixel_mask[idx, :, : latent.shape[1], : latent.shape[2]] = 0 + pixel_mask[idx, :, : latent.shape[1], : latent.shape[2]] = 1 # Ensure dimensions are divisible by patch_size assert max_latent_height % patch_size == 0 and max_latent_width % patch_size == 0 @@ -125,7 +117,17 @@ def process_latent_attention_mask( mean_filter = MeanFilter(kernel_size=patch_size, in_channels=num_latent_channel) vtoken_attention_mask = mean_filter(pixel_mask).squeeze(1) # remove channel dimension - # 0 means this vtoken belongs to micro-batch0 - vtoken_attention_mask[vtoken_attention_mask != -1] = 0 + vtoken_attention_mask[vtoken_attention_mask != 1] = 0 + + pixel_mask = pixel_mask.to(torch.int32) + vtoken_attention_mask = vtoken_attention_mask.to(torch.int32) + mask_assert(pixel_mask) + mask_assert(vtoken_attention_mask) return padded_latent, vtoken_attention_mask, pixel_mask + + +def replace_attn_processor(model: torch.nn.Module, attn_processor_obj: Any) -> None: + for name, submodule in model.named_modules(): + if isinstance(submodule, Attention): + submodule.processor = attn_processor_obj diff --git a/src/cogkit/finetune/utils/filters.py b/src/cogkit/finetune/utils/filters.py index bcb17dc..c0b8d63 100644 --- a/src/cogkit/finetune/utils/filters.py +++ b/src/cogkit/finetune/utils/filters.py @@ -15,7 +15,7 @@ def __init__(self, kernel_size: int, in_channels: int): self.weight = nn.Parameter(weight, requires_grad=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: batch_size, channels, height, width = x.shape if channels != self.in_channels: diff --git a/src/cogkit/finetune/utils/misc.py b/src/cogkit/finetune/utils/misc.py deleted file mode 100644 index 3c1c7f0..0000000 --- a/src/cogkit/finetune/utils/misc.py +++ /dev/null @@ -1,13 +0,0 @@ -from collections import defaultdict -from typing import Any - - -def expand_list(list_of_dict: list[dict[str, list[Any]]]) -> dict[str, list[Any]]: - """ - Expand a list of dictionaries to a dictionary of lists. 
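The latent padding and patch-level mask derivation above can be reproduced in a few lines; avg_pool2d stands in for the MeanFilter module, and the latent sizes are arbitrary examples.

import torch
import torch.nn.functional as F

latents = [torch.randn(16, 12, 12), torch.randn(16, 8, 10)]  # different spatial sizes
patch = 2
H = -(-max(l.shape[1] for l in latents) // patch) * patch  # ceil to a multiple of patch
W = -(-max(l.shape[2] for l in latents) // patch) * patch
padded = torch.zeros(len(latents), 16, H, W)
pixel_mask = torch.zeros(len(latents), 1, H, W)
for i, l in enumerate(latents):
    padded[i, :, : l.shape[1], : l.shape[2]] = l
    pixel_mask[i, :, : l.shape[1], : l.shape[2]] = 1

# A vision token is valid only if its whole patch is made of real pixels.
vtoken_mask = (F.avg_pool2d(pixel_mask, patch) == 1).int().squeeze(1)  # [B, H/patch, W/patch]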
- """ - result = defaultdict(list) - for d in list_of_dict: - for key, values in d.items(): - result[key].extend(values) - return dict(result) diff --git a/src/cogkit/utils/__init__.py b/src/cogkit/utils/__init__.py index 51aa90e..b2ec328 100644 --- a/src/cogkit/utils/__init__.py +++ b/src/cogkit/utils/__init__.py @@ -3,8 +3,14 @@ from cogkit.utils.diffusion_pipeline import get_pipeline_meta from cogkit.utils.dtype import cast_to_torch_dtype -from cogkit.utils.lora import load_lora_checkpoint, unload_lora_checkpoint -from cogkit.utils.misc import guess_generation_mode +from cogkit.utils.lora import ( + load_lora_checkpoint, + unload_lora_checkpoint, + inject_lora, + save_lora, + unload_lora, +) +from cogkit.utils.misc import guess_generation_mode, flatten_dict, expand_list from cogkit.utils.path import mkdir, resolve_path from cogkit.utils.prompt import convert_prompt from cogkit.utils.random import rand_generator @@ -15,10 +21,15 @@ "cast_to_torch_dtype", "load_lora_checkpoint", "unload_lora_checkpoint", + "inject_lora", + "save_lora", + "unload_lora", "guess_generation_mode", "mkdir", "resolve_path", "rand_generator", "load_pipeline", "convert_prompt", + "flatten_dict", + "expand_list", ] diff --git a/src/cogkit/utils/lora.py b/src/cogkit/utils/lora.py index 2c09d97..db9e142 100644 --- a/src/cogkit/utils/lora.py +++ b/src/cogkit/utils/lora.py @@ -1,19 +1,127 @@ # -*- coding: utf-8 -*- +""" +LoRA utility functions for model fine-tuning. + +This module provides a user-friendly interface for working with LoRA adapters +based on Huggings PEFT (Parameter-Efficient Fine-Tuning) library. +It simplifies the process of injecting, saving, loading, and unloading LoRA +adapters for transformer models. + +For more details, refer to: https://huggingface.co/docs/peft/developer_guides/low_level_api +""" + +from pathlib import Path + +from peft import ( + LoraConfig, + get_peft_model_state_dict, + inject_adapter_in_model, + set_peft_model_state_dict, +) +from safetensors.torch import load_file, save_file from diffusers.loaders import CogVideoXLoraLoaderMixin, CogView4LoraLoaderMixin +from diffusers.utils import recurse_remove_peft_layers + +# Standard filename for LoRA adapter weights +_LORA_WEIGHT_NAME = "adapter_model.safetensors" + + +def _get_lora_config() -> LoraConfig: + return LoraConfig( + r=128, + lora_alpha=64, + init_lora_weights=True, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + ) + + +def inject_lora(model, lora_dir_or_state_dict: str | Path | None = None) -> None: + """ + Inject LoRA adapters into the model. + + This function adds LoRA layers to the specified model. If a LoRA checkpoint + is provided, it will load the weights from that checkpoint. Otherwise, it + will initialize the LoRA weights randomly. 
+ + Args: + model: The model to inject LoRA adapters into + lora_dir_or_state_dict: Path to a LoRA checkpoint directory, a state dict, + or None for random initialization + """ + transformer_lora_config = _get_lora_config() + inject_adapter_in_model(transformer_lora_config, model) + if lora_dir_or_state_dict is None: + return + + if isinstance(lora_dir_or_state_dict, str) or isinstance(lora_dir_or_state_dict, Path): + lora_dir = Path(lora_dir_or_state_dict) + lora_fpath = lora_dir / _LORA_WEIGHT_NAME + assert lora_dir.exists(), f"LORA checkpoint directory {lora_dir} does not exist" + assert lora_fpath.exists(), f"LORA checkpoint file {lora_fpath} does not exist" + + peft_state_dict = load_file(lora_fpath, device="cpu") + else: + peft_state_dict = lora_dir_or_state_dict + + set_peft_model_state_dict(model, peft_state_dict) + + +def save_lora(model, lora_dir: str | Path) -> None: + """ + Save the LoRA adapter weights from a model to disk. + + Args: + model: The model containing LoRA adapters to save + lora_dir: Directory path where the LoRA weights will be saved + + Raises: + ValueError: If no LoRA weights are found in the model + """ + lora_dir = Path(lora_dir) + peft_state_dict = get_peft_model_state_dict(model) + if not peft_state_dict: + raise ValueError("No LoRA weights found in the model") + + lora_fpath = lora_dir / _LORA_WEIGHT_NAME + save_file(peft_state_dict, lora_fpath, metadata={"format": "pt"}) + + +def unload_lora(model) -> None: + """ + Remove all LoRA adapters from the model. + + This function recursively removes all PEFT (LoRA) layers from the model, + returning it to its original state without the adapters. + + Args: + model: The model from which to remove LoRA adapters + """ + recurse_remove_peft_layers(model) def load_lora_checkpoint( pipeline: CogVideoXLoraLoaderMixin | CogView4LoraLoaderMixin, - lora_model_id_or_path: str, - lora_scale: float = 1.0, + lora_dir: str | Path, ) -> None: - pipeline.load_lora_weights(lora_model_id_or_path, lora_scale=lora_scale) - # pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale) + """ + Load a LoRA checkpoint into a pipeline. + + This is a convenience function that injects LoRA adapters into the transformer + component of the specified pipeline and loads the weights from the checkpoint. + """ + lora_dir = Path(lora_dir) + inject_lora(pipeline.transformer, lora_dir) def unload_lora_checkpoint( pipeline: CogVideoXLoraLoaderMixin | CogView4LoraLoaderMixin, ) -> None: - pipeline.unload_lora_weights() + """ + Remove LoRA adapters from a pipeline. + + This is a convenience function that removes all LoRA adapters from the + transformer component of the specified pipeline. + """ + unload_lora(pipeline.transformer) diff --git a/src/cogkit/utils/misc.py b/src/cogkit/utils/misc.py index 087d10c..0af5c98 100644 --- a/src/cogkit/utils/misc.py +++ b/src/cogkit/utils/misc.py @@ -2,6 +2,7 @@ from pathlib import Path +from typing import Any from PIL import Image @@ -111,3 +112,83 @@ def guess_generation_mode( ) return GenerationMode.TextToVideo + + +def flatten_dict(d: dict[str, Any], ignore_none: bool = False) -> dict[str, Any]: + """ + Flattens a nested dictionary into a single layer dictionary. 
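Putting the new LoRA helpers together, a possible usage flow is sketched below; the model id, paths, and the bare pipeline call are placeholders, and save_lora expects the target directory to already exist.

from pathlib import Path

from diffusers import CogView4Pipeline  # assumes a diffusers release that ships CogView4

from cogkit.utils import inject_lora, load_lora_checkpoint, save_lora, unload_lora_checkpoint

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B")  # placeholder model id

load_lora_checkpoint(pipe, "/path/to/lora_dir")  # reads adapter_model.safetensors from the directory
image = pipe(prompt="a watercolor fox").images[0]
unload_lora_checkpoint(pipe)  # strips the adapters, restoring the base transformer

# Lower-level path: inject randomly initialized adapters and save them.
inject_lora(pipe.transformer)
Path("/path/to/new_lora_dir").mkdir(parents=True, exist_ok=True)
save_lora(pipe.transformer, "/path/to/new_lora_dir")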
+ + Args: + d: The dictionary to flatten + ignore_none: If True, keys with None values will be omitted + + Returns: + A flattened dictionary + + Raises: + ValueError: If there are duplicate keys across nested dictionaries + + Examples: + >>> flatten_dict({"a": 1, "b": {"c": 2, "d": {"e": 3}}, "f": None}) + {"a": 1, "c": 2, "e": 3, "f": None} + + >>> flatten_dict({"a": 1, "b": {"c": 2, "d": {"e": 3}}, "f": None}, ignore_none=True) + {"a": 1, "c": 2, "e": 3} + + >>> flatten_dict({"a": 1, "b": {"a": 2}}) + ValueError: Duplicate key 'a' found in nested dictionary + """ + result = {} + + def _flatten(current_dict, result_dict): + for key, value in current_dict.items(): + if value is None and ignore_none: + continue + + if isinstance(value, dict): + _flatten(value, result_dict) + else: + if key in result_dict: + raise ValueError(f"Duplicate key '{key}' found in nested dictionary") + result_dict[key] = value + + _flatten(d, result) + return result + + +def expand_list(dicts: list[dict[str, Any]]) -> dict[str, list[Any]]: + """ + Converts a list of dictionaries into a dictionary of lists. + + For each key in the dictionaries, collects all values corresponding to that key + into a list. + + Args: + dicts: A list of dictionaries + + Returns: + A dictionary where each key maps to a list of values from the input dictionaries + + Examples: + >>> expand_list([{"a": 1, "b": 2}, {"a": 3, "b": 4, "c": 5}]) + {"a": [1, 3], "b": [2, 4], "c": [5]} + + >>> expand_list([{"x": "value1"}, {"y": "value2"}, {"x": "value3"}]) + {"x": ["value1", "value3"], "y": ["value2"]} + + >>> expand_list([{"x": ["value1", "value2"]}, {"y": "value3"}, {"x": ["value4"]}]) + {"x": ["value1", "value2", "value4"], "y": ["value3"]} + """ + result = {} + + for d in dicts: + for key, value in d.items(): + if key not in result: + result[key] = [] + + if isinstance(value, list): + result[key].extend(value) + else: + result[key].append(value) + + return result diff --git a/tests/test_expand_list.py b/tests/test_expand_list.py index 6f6677e..45f941c 100644 --- a/tests/test_expand_list.py +++ b/tests/test_expand_list.py @@ -1,6 +1,4 @@ -import pytest - -from cogkit.finetune.utils import expand_list +from cogkit.utils import expand_list class TestExpandList: @@ -29,12 +27,22 @@ def test_mixed_keys(self): expected_output = {"a": [1, 3], "b": [2, 4], "c": [5]} assert expand_list(input_data) == expected_output - def test_non_list_values(self): - input_data = [{"a": 1}] - with pytest.raises(TypeError): - expand_list(input_data) - def test_empty_values(self): input_data = [{"a": [], "b": [1]}, {"a": [2], "b": []}] expected_output = {"a": [2], "b": [1]} assert expand_list(input_data) == expected_output + + def test_only_list_values(self): + input_data = [{"a": ["x", "y"]}, {"b": ["z"]}, {"a": ["w"]}] + expected_output = {"a": ["x", "y", "w"], "b": ["z"]} + assert expand_list(input_data) == expected_output + + def test_only_single_values(self): + input_data = [{"a": 1, "b": 2}, {"a": 3, "c": 4}] + expected_output = {"a": [1, 3], "b": [2], "c": [4]} + assert expand_list(input_data) == expected_output + + def test_mixed_list_and_single_values(self): + input_data = [{"a": ["x", "y"], "b": 1}, {"a": 2, "b": ["z", "w"]}] + expected_output = {"a": ["x", "y", 2], "b": [1, "z", "w"]} + assert expand_list(input_data) == expected_output diff --git a/tools/resize_img.py b/tools/resize_img.py new file mode 100755 index 0000000..1223dee --- /dev/null +++ b/tools/resize_img.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +import os +import sys +import argparse +from 
PIL import Image + + +def resize_images(directory, max_width, max_height): + max_pixels = max_width * max_height + + for filename in os.listdir(directory): + if filename.lower().endswith((".png", ".jpg", ".jpeg")): + file_path = os.path.join(directory, filename) + + try: + with Image.open(file_path) as img: + width, height = img.size + current_pixels = width * height + + if current_pixels > max_pixels: + # Calculate new dimensions while maintaining aspect ratio + ratio = (max_pixels / current_pixels) ** 0.5 + new_width = int(width * ratio) + new_height = int(height * ratio) + + print( + f"Resizing {filename}: {width}x{height} ({current_pixels} pixels) -> {new_width}x{new_height} ({new_width * new_height} pixels)" + ) + + # Resize and save with original format + resized_img = img.resize((new_width, new_height), Image.LANCZOS) + resized_img.save(file_path, quality=95) + else: + print( + f"Skipping {filename}: {width}x{height} ({current_pixels} pixels) <= {max_pixels} pixels" + ) + except Exception as e: + print(f"Error processing {filename}: {str(e)}") + + +def main(): + parser = argparse.ArgumentParser( + description="Resize images in a directory if they exceed a pixel count threshold" + ) + parser.add_argument( + "--resolution", help="Maximum resolution in format HEIGHTxWIDTH (e.g. 1080x1920)" + ) + parser.add_argument("--imgdir", help="Directory containing images to process") + + args = parser.parse_args() + + if not args.resolution or "x" not in args.resolution: + print("Error: --resolution must be specified in format HEIGHTxWIDTH (e.g. 1080x1920)") + sys.exit(1) + + try: + height, width = map(int, args.resolution.lower().split("x")) + height -= 16 + width -= 16 + except ValueError: + print("Error: Invalid resolution format. Use HEIGHTxWIDTH (e.g. 1080x1920)") + sys.exit(1) + + if not os.path.isdir(args.imgdir): + print(f"Error: {args.imgdir} is not a valid directory") + sys.exit(1) + + resize_images(args.imgdir, width, height) + + +if __name__ == "__main__": + main()
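For reference, the tool above can be driven from the shell or, assuming the script's directory is importable, directly from Python; the paths below are placeholders.

# Shell usage (resolution is HEIGHTxWIDTH; main() subtracts 16 from each dimension before resizing):
#   python tools/resize_img.py --resolution 1080x1920 --imgdir ./data/images

# Rough Python equivalent of the call main() ends up making:
from resize_img import resize_images  # assumes tools/ is on sys.path

resize_images("./data/images", max_width=1920 - 16, max_height=1080 - 16)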