diff --git a/OmniGen/model.py b/OmniGen/model.py
index 5389931..ee07299 100644
--- a/OmniGen/model.py
+++ b/OmniGen/model.py
@@ -1,5 +1,8 @@
 # The code is revised from DiT
 import os
+import gc
+import warnings
+from pathlib import Path
 import torch
 import torch.nn as nn
 import numpy as np
@@ -7,13 +10,19 @@
 from typing import Dict
 
 from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import logging
 from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file
+from accelerate import init_empty_weights
+from transformers import BitsAndBytesConfig
 
 from OmniGen.transformer import Phi3Config, Phi3Transformer
+from OmniGen.utils import quantize_bnb
+
+logger = logging.get_logger(__name__)
+
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -162,6 +171,7 @@ def __init__(
         pos_embed_max_size: int = 192,
     ):
         super().__init__()
+        self.config = transformer_config
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
@@ -185,22 +195,78 @@ def __init__(
 
         self.llm = Phi3Transformer(config=transformer_config)
         self.llm.config.use_cache = False
+
+        # bnb quantized models cannot easily be offloaded or recast
+        self.quantized = False
+        self.dtype = None
 
     @classmethod
-    def from_pretrained(cls, model_name):
-        if not os.path.exists(model_name):
+    def from_pretrained(cls, model_name: str | os.PathLike, dtype: torch.dtype = None, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True):
+        model_path = Path(model_name)
+        config_loc = model_name  # these only diverge when model_name is a *.safetensors or *.pt file
+
+        if model_path.exists():
+            if model_path.is_dir():
+                if (weights_loc := list(model_path.glob('*.safetensors'))):
+                    model_path = weights_loc[0]
+                elif (weights_loc := list(model_path.glob('*.pt'))):
+                    model_path = weights_loc[0]
+                else:
+                    raise FileNotFoundError(f'No .safetensors or .pt model weights found in {model_path.as_posix()!r}')
+            else:
+                logger.info("Loading model weights from file. Using default config from 'Shitao/OmniGen-v1'.")
+                config_loc = "Shitao/OmniGen-v1"
+        else:
             cache_folder = os.getenv('HF_HUB_CACHE')
-            model_name = snapshot_download(repo_id=model_name,
-                                           cache_dir=cache_folder,
-                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
-        config = Phi3Config.from_pretrained(model_name)
-        model = cls(config)
-        if os.path.exists(os.path.join(model_name, 'model.safetensors')):
-            print("Loading safetensors")
-            ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
+            model_path = snapshot_download(repo_id=model_name, cache_dir=cache_folder,
+                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+
+            # assume hub files are always .safetensors
+            model_path = next(Path(model_path).glob('*.safetensors'))
+
+        ckpt = (load_file(model_path, 'cpu') if model_path.suffix == '.safetensors' else
+                torch.load(model_path, map_location='cpu'))
+
+        config = Phi3Config.from_pretrained(config_loc)
+        # avoid inadvertently leaving the weights as float32
+        if dtype is None:
+            dtype = config.torch_dtype
+
+        if hasattr(config, 'quantization_config'):
+            if quantization_config is not None:
+                # from: diffusers.quantizers.auto
+                warnings.warn(
+                    "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
+                    " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
+                )
+
+            config.quantization_config.pop("quant_method", None)  # prevent unused keys warning
+            quantization_config = BitsAndBytesConfig.from_dict(config.quantization_config)
+
+        if low_cpu_mem_usage:
+            with init_empty_weights():
+                model = cls(config)
+
+            if quantization_config:
+                model = quantize_bnb(model, ckpt, quantization_config=quantization_config, dtype=dtype)
+                model.quantized = True
+                model.config.quantization_config = quantization_config
+            else:
+                model.load_state_dict(ckpt, assign=True)
         else:
-            ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
-        model.load_state_dict(ckpt)
+            if quantization_config:
+                raise ValueError('Quantization not supported for `low_cpu_mem_usage=False`.')
+
+            model = cls(config)
+            model.load_state_dict(ckpt)
+
+
+        # determine dtype via the x_embedder bias since, as a Conv2d bias, it should never be quantized
+        model.dtype = model.x_embedder.proj.bias.dtype
+
+        del ckpt
+        torch.cuda.empty_cache()
+        gc.collect()
         return model
 
     def initialize_weights(self):
diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py
index 572452b..950ea33 100644
--- a/OmniGen/pipeline.py
+++ b/OmniGen/pipeline.py
@@ -1,6 +1,7 @@
 import os
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Union, Literal
 import gc
 
 from PIL import Image
@@ -17,6 +18,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
+from transformers import BitsAndBytesConfig
 from safetensors.torch import load_file
 
 from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
@@ -41,6 +43,15 @@
         ```
 """
 
+def best_available_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        logger.info("No available GPU detected; using CPU instead. Image generation may be very slow!")
+        device = torch.device("cpu")
+    return device
 
 class OmniGenPipeline:
     def __init__(
@@ -55,14 +66,10 @@ def __init__(
         self.processor = processor
         self.device = device
 
-        if device is None:
-            if torch.cuda.is_available():
-                self.device = torch.device("cuda")
-            elif torch.backends.mps.is_available():
-                self.device = torch.device("mps")
-            else:
-                logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!")
-                self.device = torch.device("cpu")
+        if self.device is None:
+            self.device = best_available_device()
+        elif isinstance(self.device, str):
+            self.device = torch.device(self.device)
 
         # self.model.to(torch.bfloat16)
         self.model.eval()
@@ -71,28 +78,46 @@ def __init__(
         self.model_cpu_offload = False
 
     @classmethod
-    def from_pretrained(cls, model_name, vae_path: str=None):
-        if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
-            # logger.info("Model not found, downloading...")
-            print("Model not found, downloading...")
-            cache_folder = os.getenv('HF_HUB_CACHE')
-            model_name = snapshot_download(repo_id=model_name,
-                                           cache_dir=cache_folder,
-                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
-            # logger.info(f"Downloaded model to {model_name}")
-            print(f"Downloaded model to {model_name}")
-        model = OmniGen.from_pretrained(model_name)
-        processor = OmniGenProcessor.from_pretrained(model_name)
-
-        if os.path.exists(os.path.join(model_name, "vae")):
-            vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
-        elif vae_path is not None:
-            vae = AutoencoderKL.from_pretrained(vae_path).to(device)
-        else:
-            logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
-            vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
+    def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True, **kwargs):
+        pretrained_path = Path(model_name)
+
+        # XXX: Consider renaming 'model' to 'transformer' to conform to the diffusers pipeline naming
+        model = kwargs.get('model', None)
+        processor = kwargs.get('processor', None)
+        vae = kwargs.get('vae', None)
 
-        return cls(vae, model, processor)
+        # NOTE: should technically allow delayed component inits via model/vae = None, but that seems like more of a footgun than it's worth at this point
+
+        if not pretrained_path.exists():
+            ignore_patterns = ['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt']
+
+            if model is not None:
+                ignore_patterns.append('model.safetensors')  # avoid downloading the bf16 weights when an existing model is passed
+
+            logger.info("Model not found, downloading...")
+            cache_folder = os.getenv('HF_HUB_CACHE')
+            pretrained_path = Path(snapshot_download(repo_id=model_name, cache_dir=cache_folder, ignore_patterns=ignore_patterns))
+            logger.info(f"Downloaded model to {pretrained_path}")
+
+        if model is None:
+            model = OmniGen.from_pretrained(pretrained_path, dtype=torch.bfloat16, quantization_config=None, low_cpu_mem_usage=low_cpu_mem_usage)
+
+        model = model.requires_grad_(False).eval()
+
+        if processor is None:
+            processor = OmniGenProcessor.from_pretrained(model_name)
+
+        if vae is None:
+            if vae_path is None:
+                vae_path = pretrained_path.joinpath("vae")
+
+            if not os.path.exists(vae_path):
+                logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
+                vae_path = "stabilityai/sdxl-vae"
+
+            vae = AutoencoderKL.from_pretrained(vae_path, low_cpu_mem_usage=low_cpu_mem_usage)
+
+        return cls(vae, model, processor, device)
 
     def merge_lora(self, lora_path: str):
         model = PeftModel.from_pretrained(self.model, lora_path)
@@ -123,7 +148,8 @@ def move_to_device(self, data):
 
     def enable_model_cpu_offload(self):
         self.model_cpu_offload = True
-        self.model.to("cpu")
+        if not self.model.quantized:
+            self.model.to("cpu")
         self.vae.to("cpu")
         torch.cuda.empty_cache()  # Clear VRAM
         gc.collect()  # Run garbage collection to free system RAM
@@ -212,7 +238,13 @@ def __call__(
         # set model and processor
         if max_input_image_size != self.processor.max_image_size:
            self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size)
-        self.model.to(dtype)
+
+        if not self.model.quantized:
+            self.model.dtype = dtype
+            self.model.to(dtype)
+
+        # self.vae.to(dtype)  # Uncomment this line to allow a bfloat16 VAE
+
         if offload_model:
             self.enable_model_cpu_offload()
         else:
@@ -234,7 +266,7 @@ def __call__(
         else:
             generator = None
         latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
-        latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
+        latents = torch.cat([latents]*(1+num_cfg), 0).to(self.model.dtype)
         if input_images is not None and self.model_cpu_offload:
             self.vae.to(self.device)
         input_img_latents = []
@@ -242,12 +274,12 @@ def __call__(
             for temp_pixel_values in input_data['input_pixel_values']:
                 temp_input_latents = []
                 for img in temp_pixel_values:
-                    img = self.vae_encode(img.to(self.device), dtype)
+                    img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), self.model.dtype)
                     temp_input_latents.append(img)
                 input_img_latents.append(temp_input_latents)
         else:
             for img in input_data['input_pixel_values']:
-                img = self.vae_encode(img.to(self.device), dtype)
+                img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), self.model.dtype)
                 input_img_latents.append(img)
         if input_images is not None and self.model_cpu_offload:
             self.vae.to('cpu')
@@ -263,7 +295,7 @@ def __call__(
             img_cfg_scale=img_guidance_scale,
             use_img_cfg=use_img_guidance,
             use_kv_cache=use_kv_cache,
-            offload_model=offload_model,
+            offload_model=(offload_model and not self.model.quantized),
         )
 
         if separate_cfg_infer:
@@ -271,28 +303,37 @@ def __call__(
         else:
            func = self.model.forward_with_cfg
 
-        if self.model_cpu_offload:
+        if self.model_cpu_offload and not self.model.quantized:
             for name, param in self.model.named_parameters():
                 if 'layers' in name and 'layers.0' not in name:
-                    param.data = param.data.cpu()
+                    param.data = param.data.to('cpu')
                 else:
                     param.data = param.data.to(self.device)
             for buffer_name, buffer in self.model.named_buffers():
                 setattr(self.model, buffer_name, buffer.to(self.device))
+            torch.cuda.empty_cache()
+            gc.collect()
         # else:
         #     self.model.to(self.device)
 
         scheduler = OmniGenScheduler(num_steps=num_inference_steps)
+        if latents.dtype == torch.float16:
+            # Continue to monitor. If _clip_val never changes, the scheduler autoset func can be removed and the clip val hardcoded here.
+            # self.model.llm.set_clip_val(2**16-32 - 2*32)  # hardcode clip val
+            # dry run the inputs, adjusting the clip bounds as necessary
+            scheduler._fp16_clip_autoset(self.model.llm, latents, func, model_kwargs)
         samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
         samples = samples.chunk((1+num_cfg), dim=0)[0]
 
         if self.model_cpu_offload:
-            self.model.to('cpu')
+            if not self.model.quantized:
+                self.model.to("cpu")
+                torch.cuda.empty_cache()
             gc.collect()
             self.vae.to(self.device)
-        samples = samples.to(torch.float32)
+        samples = samples.to(self.vae.dtype)
         if self.vae.config.shift_factor is not None:
             samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
         else:
diff --git a/OmniGen/scheduler.py b/OmniGen/scheduler.py
index ffa99cd..cf294f1 100644
--- a/OmniGen/scheduler.py
+++ b/OmniGen/scheduler.py
@@ -1,10 +1,13 @@
+import copy
 from tqdm import tqdm
 from typing import Optional, Dict, Any, Tuple, List
 import gc
 
 import torch
 from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
+from diffusers.utils import logging
 
+logger = logging.get_logger(__name__)
 
 class OmniGenCache(DynamicCache):
@@ -121,8 +124,43 @@ def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
         t = torch.linspace(0, 1, num_steps+1)
         t = t / (t + time_shifting_factor - time_shifting_factor * t)
         self.sigma = t
-        
+
+    @torch.no_grad()
+    def _fp16_clip_autoset(self, model_llm, z, func, model_kwargs):
+        '''Recursively search for a minimal clipping value for fp16 stability'''
+        fp16_max_repr = torch.finfo(torch.float16).max  # fp16 max representable: ±(2^16 - 32)
+        timesteps = torch.full(size=(len(z), ), fill_value=self.sigma[0], device=z.device)
+        _buff_expon = model_kwargs.pop('_buff_expon', None)  # temp local recursion var
+
+        if _buff_expon is None:
+            # fp16 overflows at ±(2^16 - 16), the largest representable value being ±(2^16 - 32). Representable values occur at intervals of 32 for nums > 2^15.
+            # Prelim tests show an additional buffer of at least 2 representable values is needed for stability; why is presently unclear.
+            # If this continues to hold true, this function can be deleted and replaced with 1 line in pipeline.
+            clip_val = fp16_max_repr - 2*32  # = 2**6 (-2,+2 buffer vals)
+            if model_llm._clip_val is None or model_llm._clip_val > clip_val:
+                model_llm.set_clip_val(clip_val)
+                logger.debug(f'set initial clamp: (+-){clip_val} ...')
+        else:
+            clip_val = fp16_max_repr - 2**_buff_expon
+            model_llm.set_clip_val(clip_val)  # clamp (-clip_val, +clip_val)
+
+        try:
+            _model_kwargs = copy.deepcopy(model_kwargs)
+            _model_kwargs['use_kv_cache'] = False  # no cache while searching
+            _, _ = func(z.clone(), timesteps, past_key_values=None, **_model_kwargs)
+        except OverflowError:
+            if _buff_expon is None:
+                _buff_expon = 6  # start at 2**(6 + 1) (-4,+4 buffer vals)
+                logger.info('FP16 overflow, searching for clamp bounds...')
+
+            if _buff_expon < 15:  # stop at 2**15 (-1024,+1024 buffer vals)
+                _buff_expon += 1
+                # each iteration, double the representable-value buffer for both min and max
+                model_kwargs['_buff_expon'] = _buff_expon
+                logger.debug(f'trying clamp: (+-){fp16_max_repr - 2**(_buff_expon)} ...')
+                return self._fp16_clip_autoset(model_llm, z, func, model_kwargs)
+            raise OverflowError('Numerical overflow, unable to find suitable clipping bounds.')
+
     def crop_kv_cache(self, past_key_values, num_tokens_for_img):
         # return
         crop_past_key_values = ()
diff --git a/OmniGen/transformer.py b/OmniGen/transformer.py
index a2672a4..e843758 100644
--- a/OmniGen/transformer.py
+++ b/OmniGen/transformer.py
@@ -29,6 +29,11 @@ class Phi3Transformer(Phi3Model):
     Args:
         config: Phi3Config
     """
+    _clip_val: float = None  # fp16: ~ (2**16 - 2**7)
+
+    def set_clip_val(self, clip_val: float = None):
+        self._clip_val = abs(clip_val) if clip_val is not None else None
+
     def prefetch_layer(self, layer_idx: int, device: torch.device):
         "Starts prefetching the next layer cache"
         with torch.cuda.stream(self.prefetch_stream):
@@ -42,7 +47,7 @@ def evict_previous_layer(self, layer_idx: int):
         for name, param in self.layers[prev_layer_idx].named_parameters():
             param.data = param.data.to("cpu", non_blocking=True)
 
-    def get_offlaod_layer(self, layer_idx: int, device: torch.device):
+    def get_offload_layer(self, layer_idx: int, device: torch.device):
         # init stream
         if not hasattr(self, "prefetch_stream"):
             self.prefetch_stream = torch.cuda.Stream()
@@ -137,6 +142,8 @@ def forward(
 
         for decoder_layer in self.layers:
             layer_idx += 1
+            if self._clip_val is not None:
+                hidden_states.clamp_(-self._clip_val, self._clip_val)
 
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -153,7 +160,7 @@ def forward(
                 )
             else:
                 if offload_model and not self.training:
-                    self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
+                    self.get_offload_layer(layer_idx, device=inputs_embeds.device)
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
@@ -173,7 +180,8 @@ def forward(
                 all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
-
+        if hidden_states.isnan().any():
+            raise OverflowError('Numerical overflow: hidden states contain NaNs')
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
diff --git a/OmniGen/utils.py b/OmniGen/utils.py
index 67a64e8..e6b2349 100644
--- a/OmniGen/utils.py
+++ b/OmniGen/utils.py
@@ -1,9 +1,14 @@
+import gc
 import logging
 
 from PIL import Image
 import torch
 import numpy as np
 
+from transformers import BitsAndBytesConfig
+from transformers.quantizers import AutoHfQuantizer
+from transformers.integrations import replace_with_bnb_linear, set_module_quantized_tensor_to_device, get_keys_to_not_convert
+
 def create_logger(logging_dir):
     """
     Create a logger that writes to a log file and stdout.
@@ -108,3 +113,44 @@ def vae_encode_list(vae, x, weight_dtype):
         latents.append(img)
     return latents
 
+
+
+@torch.no_grad()
+def quantize_bnb(meta_model, state_dict: dict, quantization_config: BitsAndBytesConfig, pre_quantized=None, dtype=None):
+    if pre_quantized is None:
+        if quantization_config.load_in_4bit:
+            pre_quantized = any('bitsandbytes__' in k for k in state_dict)
+        elif quantization_config.load_in_8bit:
+            pre_quantized = any('weight_format' in k for k in state_dict)
+
+    if quantization_config.llm_int8_skip_modules is None:
+        quantization_config.llm_int8_skip_modules = get_keys_to_not_convert(meta_model.llm)  # ['norm']
+
+    quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=pre_quantized)
+
+    meta_model.eval()
+    meta_model.requires_grad_(False)
+
+    model = meta_model
+
+    quantizer.preprocess_model(model, device_map=None)
+
+    # iterate the model's keys; iterating the quantized state_dict directly will throw errors
+    for param_name in model.state_dict():
+        param = state_dict[param_name]
+        if not pre_quantized:
+            param = param.to(dtype)
+
+        if not quantizer.check_quantized_param(model, param, param_name, state_dict):
+            set_module_quantized_tensor_to_device(model, param_name, device=0, value=param)
+        else:
+            quantizer.create_quantized_param(model, param, param_name, target_device=0, state_dict=state_dict)
+
+        del state_dict[param_name], param
+
+    model = quantizer.postprocess_model(model)
+
+    del state_dict
+    torch.cuda.empty_cache()
+    gc.collect()
+    return model
\ No newline at end of file
diff --git a/app.py b/app.py
index ba87673..3dbe96b 100644
--- a/app.py
+++ b/app.py
@@ -4,13 +4,9 @@
 import argparse
 import random
 import spaces
+from transformers import BitsAndBytesConfig
 
-
-from OmniGen import OmniGenPipeline
-
-pipe = OmniGenPipeline.from_pretrained(
-    "Shitao/OmniGen-v1"
-)
+from OmniGen import OmniGenPipeline, OmniGen
 
 @spaces.GPU(duration=180)
 def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
@@ -370,6 +366,8 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_
 
         with gr.Column():
             with gr.Column():
+                # quantization = gr.Radio(["4bit (NF4)", "8bit", "None (bf16)"], label="bitsandbytes quantization", value="4bit (NF4)")
+                # quantization.input(change_quantization, inputs=quantization, trigger_mode="once", concurrency_limit=1)
                 # output image
                 output_image = gr.Image(label="Output Image")
                 save_images = gr.Checkbox(label="Save generated images", value=False)
@@ -425,7 +423,21 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Run the OmniGen')
     parser.add_argument('--share', action='store_true', help='Share the Gradio app')
+    parser.add_argument('-b', '--nbits', choices=['4', '8'], help='bitsandbytes quantization n-bits')
     args = parser.parse_args()
+
+    quantization_config = None
+    model = None
+    if args.nbits:
+        if args.nbits == '4':
+            quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4')
+        elif args.nbits == '8':
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+        model = OmniGen.from_pretrained("Shitao/OmniGen-v1", quantization_config=quantization_config)
+
+    pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", model=model)
+
     # launch
     demo.launch(share=args.share)
diff --git a/requirements.txt b/requirements.txt
index 358c613..0600e8e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ pillow==10.2.0
 peft==0.13.2
 diffusers==0.30.3
 timm==0.9.16
+bitsandbytes==0.44.1
\ No newline at end of file
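
Usage sketch: the example below simply mirrors the `--nbits 4` branch added to `app.py` and assumes only the interfaces introduced in this diff (the `quantization_config` argument of `OmniGen.from_pretrained` and the `model` kwarg of `OmniGenPipeline.from_pretrained`); treat it as a minimal illustration rather than a tested recipe.

```python
from transformers import BitsAndBytesConfig
from OmniGen import OmniGen, OmniGenPipeline

# 4-bit NF4 weight quantization, as in the `--nbits 4` branch of app.py
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

# Load the transformer quantized, then pass it through the new `model` kwarg so the
# pipeline skips downloading/loading a second full-precision copy of the weights.
model = OmniGen.from_pretrained("Shitao/OmniGen-v1", quantization_config=quantization_config)
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", model=model)

# pipe(...) is then called as before; quantized models simply skip the CPU-offload paths.
```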
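For the fp16 stability changes in `scheduler.py` and `transformer.py`, here is a standalone sketch of the underlying numerics (plain PyTorch, no OmniGen code; variable names are local to the example): hidden states are clamped symmetrically just below the largest representable float16 value, keeping activations finite where an unclamped cast would overflow to inf.

```python
import torch

fp16_max = torch.finfo(torch.float16).max   # 65504 == 2**16 - 32
clip_val = fp16_max - 2 * 32                # initial (-2, +2) representable-value buffer, as in _fp16_clip_autoset

h = torch.tensor([70000.0, -70000.0, 123.0])  # values that overflow a direct fp16 cast
assert torch.isinf(h.half()).any()            # -> inf, -inf, 123.0
assert torch.isfinite(h.clamp(-clip_val, clip_val).half()).all()  # -> 65440.0, -65440.0, 123.0
```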