diff --git a/modules/dataLoader/Flux2BaseDataLoader.py b/modules/dataLoader/Flux2BaseDataLoader.py new file mode 100644 index 000000000..d44c7ce58 --- /dev/null +++ b/modules/dataLoader/Flux2BaseDataLoader.py @@ -0,0 +1,147 @@ +import os + +from modules.dataLoader.BaseDataLoader import BaseDataLoader +from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.Flux2Model import HIDDEN_STATES_LAYERS, SYSTEM_MESSAGE, Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.TrainProgress import TrainProgress + +from mgds.pipelineModules.DecodeTokens import DecodeTokens +from mgds.pipelineModules.DecodeVAE import DecodeVAE +from mgds.pipelineModules.EncodeMistralText import EncodeMistralText +from mgds.pipelineModules.EncodeVAE import EncodeVAE +from mgds.pipelineModules.RescaleImageChannels import RescaleImageChannels +from mgds.pipelineModules.SampleVAEDistribution import SampleVAEDistribution +from mgds.pipelineModules.SaveImage import SaveImage +from mgds.pipelineModules.SaveText import SaveText +from mgds.pipelineModules.ScaleImage import ScaleImage +from mgds.pipelineModules.Tokenize import Tokenize + +from diffusers.pipelines.flux2.pipeline_flux2 import format_input + + +class Flux2BaseDataLoader( #TODO share code + BaseDataLoader, + DataLoaderText2ImageMixin, +): + def _preparation_modules(self, config: TrainConfig, model: Flux2Model): + rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) + encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean') + downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125) + tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length, + apply_chat_template = lambda caption: format_input([caption], SYSTEM_MESSAGE), apply_chat_template_kwargs = {'add_generation_prompt': False}, + ) + encode_prompt = EncodeMistralText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask', + text_encoder=model.text_encoder, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(), + hidden_state_output_index=HIDDEN_STATES_LAYERS, + ) + + modules = [rescale_image, encode_image, image_sample] + if config.masked_training or config.model_type.has_mask_input(): + modules.append(downscale_mask) + + modules += [tokenize_prompt, encode_prompt] + return modules + + def _cache_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup): + image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] + + if config.masked_training or config.model_type.has_mask_input(): + image_split_names.append('latent_mask') + + image_aggregate_names = ['crop_resolution', 'image_path'] + + text_split_names = [] + + sort_names = image_aggregate_names + image_split_names + [ + 'prompt', 'tokens', 'tokens_mask', 'text_encoder_hidden_state', + 
'concept' + ] + + text_split_names += ['tokens', 'tokens_mask', 'text_encoder_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=True, + ) + + def _output_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup): + output_names = [ + 'image_path', 'latent_image', + 'prompt', + 'tokens', + 'tokens_mask', + 'original_resolution', 'crop_resolution', 'crop_offset', + ] + + if config.masked_training or config.model_type.has_mask_input(): + output_names.append('latent_mask') + + output_names.append('text_encoder_hidden_state') + + return self._output_modules_from_out_names( + model, model_setup, + output_names=output_names, + config=config, + use_conditioning_image=False, + vae=model.vae, + autocast_context=[model.autocast_context], + train_dtype=model.train_dtype, + ) + + def _debug_modules(self, config: TrainConfig, model: Flux2Model): + debug_dir = os.path.join(config.debug_dir, "dataloader") + + def before_save_fun(): + model.vae_to(self.train_device) + + decode_image = DecodeVAE(in_name='latent_image', out_name='decoded_image', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + upscale_mask = ScaleImage(in_name='latent_mask', out_name='decoded_mask', factor=8) + decode_prompt = DecodeTokens(in_name='tokens', out_name='decoded_prompt', tokenizer=model.tokenizer) + save_image = SaveImage(image_in_name='decoded_image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1, before_save_fun=before_save_fun) + # SaveImage(image_in_name='latent_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun) + save_mask = SaveImage(image_in_name='decoded_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun) + save_prompt = SaveText(text_in_name='decoded_prompt', original_path_in_name='image_path', path=debug_dir, before_save_fun=before_save_fun) + + # These modules don't really work, since they are inserted after a sorting operation that does not include this data + # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), + # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), + + modules = [] + + modules.append(decode_image) + modules.append(save_image) + + if config.masked_training or config.model_type.has_mask_input(): + modules.append(upscale_mask) + modules.append(save_mask) + + modules.append(decode_prompt) + modules.append(save_prompt) + + return modules + + def _create_dataset( + self, + config: TrainConfig, + model: Flux2Model, + model_setup: BaseFlux2Setup, + train_progress: TrainProgress, + is_validation: bool = False, + ): + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, + ) + + +factory.register(BaseDataLoader, Flux2BaseDataLoader, ModelType.FLUX_DEV_2) diff --git a/modules/model/Flux2Model.py b/modules/model/Flux2Model.py new file mode 100644 index 000000000..83c58403b --- /dev/null +++ b/modules/model/Flux2Model.py @@ -0,0 +1,298 @@ +import math +from contextlib import nullcontext +from random import Random + +from 
modules.model.BaseModel import BaseModel +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util.convert_util import add_prefix, lora_qkv_fusion, qkv_fusion, remove_prefix, swap_chunks +from modules.util.enum.ModelType import ModelType +from modules.util.LayerOffloadConductor import LayerOffloadConductor + +import torch +from torch import Tensor + +from diffusers import ( + AutoencoderKLFlux2, + DiffusionPipeline, + FlowMatchEulerDiscreteScheduler, + Flux2Pipeline, + Flux2Transformer2DModel, +) +from diffusers.pipelines.flux2.pipeline_flux2 import format_input +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." +HIDDEN_STATES_LAYERS = [10, 20, 30] + +def diffusers_to_original(qkv_fusion): + return [ + ("context_embedder", "txt_in"), + ("x_embedder", "img_in"), + ("time_guidance_embed.timestep_embedder", "time_in", [ + ("linear_1", "in_layer"), + ("linear_2", "out_layer"), + ]), + ("time_guidance_embed.guidance_embedder", "guidance_in", [ + ("linear_1", "in_layer"), + ("linear_2", "out_layer"), + ]), + ("double_stream_modulation_img.linear", "double_stream_modulation_img.lin"), + ("double_stream_modulation_txt.linear", "double_stream_modulation_txt.lin"), + ("single_stream_modulation.linear", "single_stream_modulation.lin"), + ("proj_out", "final_layer.linear"), + ("norm_out.linear", "final_layer.adaLN_modulation.1", swap_chunks, swap_chunks), + ("transformer_blocks.{i}", "double_blocks.{i}", + qkv_fusion("attn.to_q", "attn.to_k", "attn.to_v", "img_attn.qkv") + \ + qkv_fusion("attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj", "txt_attn.qkv") + [ + ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"), + ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"), + ("attn.to_out.0", "img_attn.proj"), + ("ff.linear_in", "img_mlp.0"), + ("ff.linear_out", "img_mlp.2"), + ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"), + ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"), + ("attn.to_add_out", "txt_attn.proj"), + ("ff_context.linear_in", "txt_mlp.0"), + ("ff_context.linear_out", "txt_mlp.2"), + ]), + ("single_transformer_blocks.{i}", "single_blocks.{i}", [ + ("attn.to_qkv_mlp_proj", "linear1"), + ("attn.to_out", "linear2"), + ("attn.norm_k.weight", "norm.key_norm.scale"), + ("attn.norm_q.weight", "norm.query_norm.scale"), + ]), + ] + +diffusers_lora_to_original = diffusers_to_original(lora_qkv_fusion) +diffusers_checkpoint_to_original = diffusers_to_original(qkv_fusion) +diffusers_lora_to_comfy = [remove_prefix("transformer"), diffusers_to_original(lora_qkv_fusion), add_prefix("diffusion_model")] + + +class Flux2Model(BaseModel): + # base model data + tokenizer: AutoProcessor | None + noise_scheduler: FlowMatchEulerDiscreteScheduler | None + text_encoder: Mistral3ForConditionalGeneration | None + vae: AutoencoderKLFlux2 | None + transformer: Flux2Transformer2DModel | None + + # autocast context + text_encoder_autocast_context: torch.autocast | nullcontext + + text_encoder_offload_conductor: LayerOffloadConductor | None + transformer_offload_conductor: LayerOffloadConductor | None + + transformer_lora: LoRAModuleWrapper | None + lora_state_dict: dict | None + + def __init__( + self, + model_type: ModelType, + ): + super().__init__( + model_type=model_type, + ) + + self.tokenizer = None + self.noise_scheduler = None + self.text_encoder = 
None + self.vae = None + self.transformer = None + + self.text_encoder_autocast_context = nullcontext() + + self.text_encoder_offload_conductor = None + self.transformer_offload_conductor = None + + self.transformer_lora = None + self.lora_state_dict = None + + def adapters(self) -> list[LoRAModuleWrapper]: + return [a for a in [ + self.transformer_lora, + ] if a is not None] + + def vae_to(self, device: torch.device): + self.vae.to(device=device) + + def text_encoder_to(self, device: torch.device): + if self.text_encoder is not None: + if self.text_encoder_offload_conductor is not None and \ + self.text_encoder_offload_conductor.layer_offload_activated(): + self.text_encoder_offload_conductor.to(device) + else: + self.text_encoder.to(device=device) + + def transformer_to(self, device: torch.device): + if self.transformer_offload_conductor is not None and \ + self.transformer_offload_conductor.layer_offload_activated(): + self.transformer_offload_conductor.to(device) + else: + self.transformer.to(device=device) + + if self.transformer_lora is not None: + self.transformer_lora.to(device) + + def to(self, device: torch.device): + self.vae_to(device) + self.text_encoder_to(device) + self.transformer_to(device) + + def eval(self): + self.vae.eval() + if self.text_encoder is not None: + self.text_encoder.eval() + self.transformer.eval() + + def create_pipeline(self) -> DiffusionPipeline: + return Flux2Pipeline( + transformer=self.transformer, + scheduler=self.noise_scheduler, + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + ) + + def encode_text( + self, + train_device: torch.device, + batch_size: int = 1, + rand: Random | None = None, + text: str = None, + tokens: Tensor = None, + tokens_mask: Tensor = None, + text_encoder_sequence_length: int | None = None, + text_encoder_dropout_probability: float | None = None, + text_encoder_output: Tensor = None, + ) -> tuple[Tensor, Tensor]: + if tokens is None and text is not None: + if isinstance(text, str): + text = [text] + + messages = format_input(prompts=text, system_message=SYSTEM_MESSAGE) + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + + tokenizer_output = self.tokenizer( + text, + max_length=text_encoder_sequence_length, #max length is including system message + padding='max_length', + truncation=True, + return_tensors="pt" + ) + tokens = tokenizer_output.input_ids.to(self.text_encoder.device) + tokens_mask = tokenizer_output.attention_mask.to(self.text_encoder.device) + + if text_encoder_output is None and self.text_encoder is not None: + with self.text_encoder_autocast_context: + text_encoder_output = self.text_encoder( + tokens, + attention_mask=tokens_mask.float(), + output_hidden_states=True, + use_cache=False, + ) + + text_encoder_output = torch.stack([text_encoder_output.hidden_states[k] for k in HIDDEN_STATES_LAYERS], dim=1) + batch_size, num_channels, seq_len, hidden_dim = text_encoder_output.shape + assert seq_len == text_encoder_sequence_length + text_encoder_output = text_encoder_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + if text_encoder_dropout_probability is not None and text_encoder_dropout_probability > 0.0: + raise NotImplementedError #https://github.com/Nerogar/OneTrainer/issues/957 + + return text_encoder_output + + + #code adapted from https://github.com/huggingface/diffusers/blob/c8656ed73c638e51fc2e777a5fd355d69fa5220f/src/diffusers/pipelines/flux2/pipeline_flux2.py + @staticmethod + 
def prepare_latent_image_ids(latents: torch.Tensor) -> torch.Tensor: + batch_size, _, height, width = latents.shape + + t = torch.arange(1, device=latents.device) + h = torch.arange(height, device=latents.device) + w = torch.arange(width, device=latents.device) + l_ = torch.arange(1, device=latents.device) + + latent_ids = torch.cartesian_prod(t, h, w, l_) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + #packing and unpacking on patchified latents + @staticmethod + def pack_latents(latents) -> Tensor: + batch_size, num_channels, height, width = latents.shape + return latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + @staticmethod + def unpack_latents(latents, height: int, width: int) -> Tensor: + batch_size, seq_len, num_channels = latents.shape + return latents.reshape(batch_size, height, width, num_channels).permute(0, 3, 1, 2) + + #TODO inference code uses empirical mu. But that code cannot be used for inference because it depends on num of inference steps + # is dynamic timestep shifting during training still applicable? + #unpatchified width and height + def calculate_timestep_shift(self, latent_height: int, latent_width: int) -> float: + base_seq_len = self.noise_scheduler.config.base_image_seq_len + max_seq_len = self.noise_scheduler.config.max_image_seq_len + base_shift = self.noise_scheduler.config.base_shift + max_shift = self.noise_scheduler.config.max_shift + patch_size = 2 + + image_seq_len = (latent_width // patch_size) * (latent_height // patch_size) + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return math.exp(mu) + + @staticmethod + def prepare_text_ids(x: torch.Tensor) -> torch.Tensor: + B, L, _ = x.shape + out_ids = [] + + for _ in range(B): #TODO why iterate? can text ids have different length? according to diffusers and original inference code: no + t = torch.arange(1, device=x.device) + h = torch.arange(1, device=x.device) + w = torch.arange(1, device=x.device) + l_ = torch.arange(L, device=x.device) + + coords = torch.cartesian_prod(t, h, w, l_) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def patchify_latents(latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + #scaling on patchified latents + def scale_latents(self, latents: Tensor) -> Tensor: + #TODO moves to device - necessary? save in model? 
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + return (latents - latents_bn_mean) / latents_bn_std + + + def unscale_latents(self, latents: Tensor) -> Tensor: + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + return latents * latents_bn_std + latents_bn_mean diff --git a/modules/model/FluxModel.py b/modules/model/FluxModel.py index c9ec81c88..4b85f1272 100644 --- a/modules/model/FluxModel.py +++ b/modules/model/FluxModel.py @@ -341,7 +341,7 @@ def unpack_latents(self, latents, height: int, width: int): return latents - def calculate_timestep_shift(self, latent_width: int, latent_height: int): + def calculate_timestep_shift(self, latent_height: int, latent_width: int): base_seq_len = self.noise_scheduler.config.base_image_seq_len max_seq_len = self.noise_scheduler.config.max_image_seq_len base_shift = self.noise_scheduler.config.base_shift diff --git a/modules/model/WuerstchenModel.py b/modules/model/WuerstchenModel.py index 1404663e0..b470cf019 100644 --- a/modules/model/WuerstchenModel.py +++ b/modules/model/WuerstchenModel.py @@ -170,6 +170,9 @@ def to(self, device: torch.device): self.prior_text_encoder_to(device) self.prior_prior_to(device) + def vae_to(self, device: torch.device): + raise NotImplementedError + def eval(self): if self.model_type.is_wuerstchen_v2(): self.decoder_text_encoder.eval() diff --git a/modules/modelLoader/Flux2ModelLoader.py b/modules/modelLoader/Flux2ModelLoader.py new file mode 100644 index 000000000..6329b1f95 --- /dev/null +++ b/modules/modelLoader/Flux2ModelLoader.py @@ -0,0 +1,227 @@ +import os +import traceback + +from modules.model.BaseModel import BaseModel +from modules.model.Flux2Model import Flux2Model +from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader +from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader +from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin +from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin +from modules.util.config.TrainConfig import QuantizationConfig + +#from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets #TODO +from modules.util.convert.lora.convert_lora_util import LoraConversionKeySet +from modules.util.enum.ModelType import ModelType +from modules.util.ModelNames import ModelNames +from modules.util.ModelWeightDtypes import ModelWeightDtypes + +import torch + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Flux2Transformer2DModel, + GGUFQuantizationConfig, +) +from transformers import ( + Mistral3ForConditionalGeneration, + PixtralProcessor, +) + + +class Flux2ModelLoader( + HFModelLoaderMixin, +): + def __init__(self): + super().__init__() + + def __load_internal( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + if os.path.isfile(os.path.join(base_model_name, "meta.json")): + self.__load_diffusers( + model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quantization, 
+ ) + else: + raise Exception("not an internal model") + + def __load_diffusers( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + diffusers_sub = [] + transformers_sub = ["text_encoder"] + if not transformer_model_name: + diffusers_sub.append("transformer") + if not vae_model_name: + diffusers_sub.append("vae") + + self._prepare_sub_modules( + base_model_name, + diffusers_modules=diffusers_sub, + transformers_modules=transformers_sub, + ) + + tokenizer = PixtralProcessor.from_pretrained( + base_model_name, + subfolder="tokenizer", + ).tokenizer + + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + base_model_name, + subfolder="scheduler", + ) + + text_encoder = self._load_transformers_sub_module( + Mistral3ForConditionalGeneration, + weight_dtypes.text_encoder, + weight_dtypes.fallback_train_dtype, + base_model_name, + "text_encoder", + ) + + if vae_model_name: + vae = self._load_diffusers_sub_module( + AutoencoderKLFlux2, + weight_dtypes.vae, + weight_dtypes.train_dtype, + vae_model_name, + ) + else: + vae = self._load_diffusers_sub_module( + AutoencoderKLFlux2, + weight_dtypes.vae, + weight_dtypes.train_dtype, + base_model_name, + "vae", + ) + + if transformer_model_name: + transformer = Flux2Transformer2DModel.from_single_file( + transformer_model_name, + #avoid loading the transformer in float32: + torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(), + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None, + ) + transformer = self._convert_diffusers_sub_module_to_dtype( + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quantization, + ) + else: + transformer = self._load_diffusers_sub_module( + Flux2Transformer2DModel, + weight_dtypes.transformer, + weight_dtypes.train_dtype, + base_model_name, + "transformer", + quantization, + ) + + model.model_type = model_type + model.tokenizer = tokenizer + model.noise_scheduler = noise_scheduler + model.text_encoder = text_encoder + model.vae = vae + model.transformer = transformer + + def __load_safetensors( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + #no single file .safetensors for Qwen available at the time of writing this code + raise NotImplementedError("Loading of single file Flux2 models not supported. Use the diffusers model instead. 
Optionally, transformer-only safetensor files can be loaded by overriding the transformer.") + + def load( + self, + model: Flux2Model, + model_type: ModelType, + model_names: ModelNames, + weight_dtypes: ModelWeightDtypes, + quantization: QuantizationConfig, + ): + stacktraces = [] + + try: + self.__load_internal( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + try: + self.__load_diffusers( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + try: + self.__load_safetensors( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + for stacktrace in stacktraces: + print(stacktrace) + raise Exception("could not load model: " + model_names.base_model) + + + +class Flux2LoRALoader( + LoRALoaderMixin +): + def __init__(self): + super().__init__() + + def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: + return None #TODO + #return convert_flux_lora_key_sets() + + def load( + self, + model: Flux2Model, + model_names: ModelNames, + ): + return self._load(model, model_names) + + +Flux2LoRAModelLoader = make_lora_model_loader( + model_spec_map={ + ModelType.FLUX_DEV_2: "resources/sd_model_spec/flux_dev_2.0-lora.json", + }, + model_class=Flux2Model, + model_loader_class=Flux2ModelLoader, + lora_loader_class=Flux2LoRALoader, + embedding_loader_class=None, +) + +Flux2FineTuneModelLoader = make_fine_tune_model_loader( + model_spec_map={ + ModelType.FLUX_DEV_2: "resources/sd_model_spec/flux_dev_2.0.json", + }, + model_class=Flux2Model, + model_loader_class=Flux2ModelLoader, + embedding_loader_class=None, +) diff --git a/modules/modelSampler/Flux2Sampler.py b/modules/modelSampler/Flux2Sampler.py new file mode 100644 index 000000000..33075e2fd --- /dev/null +++ b/modules/modelSampler/Flux2Sampler.py @@ -0,0 +1,190 @@ +import copy +import inspect +from collections.abc import Callable + +from modules.model.Flux2Model import Flux2Model +from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory +from modules.util.config.SampleConfig import SampleConfig +from modules.util.enum.AudioFormat import AudioFormat +from modules.util.enum.FileType import FileType +from modules.util.enum.ImageFormat import ImageFormat +from modules.util.enum.ModelType import ModelType +from modules.util.enum.NoiseScheduler import NoiseScheduler +from modules.util.enum.VideoFormat import VideoFormat +from modules.util.torch_util import torch_gc + +import torch + +from diffusers.pipelines.flux2.pipeline_flux2 import compute_empirical_mu + +import numpy as np +from tqdm import tqdm + + +class Flux2Sampler(BaseModelSampler): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + model: Flux2Model, + model_type: ModelType, + ): + super().__init__(train_device, temp_device) + + self.model = model + self.model_type = model_type + self.pipeline = model.create_pipeline() + + @torch.no_grad() + def __sample_base( + self, + prompt: str, + height: int, + width: int, + seed: int, + random_seed: bool, + diffusion_steps: int, + cfg_scale: float, + 
noise_scheduler: NoiseScheduler, + text_encoder_sequence_length: int | None = None, + on_update_progress: Callable[[int, int], None] = lambda _, __: None, + ) -> ModelSamplerOutput: + with self.model.autocast_context: + generator = torch.Generator(device=self.train_device) + if random_seed: + generator.seed() + else: + generator.manual_seed(seed) + + noise_scheduler = copy.deepcopy(self.model.noise_scheduler) + image_processor = self.pipeline.image_processor + transformer = self.pipeline.transformer + vae = self.pipeline.vae + + vae_scale_factor = 8 + num_latent_channels = 32 + patch_size = 2 + + # prepare prompt + self.model.text_encoder_to(self.train_device) + + prompt_embedding = self.model.encode_text( + text=prompt, + train_device=self.train_device, + text_encoder_sequence_length=text_encoder_sequence_length, + ) + + self.model.text_encoder_to(self.temp_device) + torch_gc() + + # prepare latent image + latent_image = torch.randn( + size=(1, num_latent_channels, height // vae_scale_factor, width // vae_scale_factor), + generator=generator, + device=self.train_device, + dtype=torch.float32, + ) + + latent_image = self.model.patchify_latents(latent_image) + image_ids = self.model.prepare_latent_image_ids(latent_image) + + #TODO test dynamic timestep shifting instead of empirical + #shift = self.model.calculate_timestep_shift(latent_image.shape[-2], latent_image.shape[-1]) + #mu = math.log(shift) + + latent_image = self.model.pack_latents(latent_image) + image_seq_len = latent_image.shape[1] + mu = compute_empirical_mu(image_seq_len, diffusion_steps) + + # prepare timesteps + #TODO for other models, too? This is different than with sigmas=None + sigmas = np.linspace(1.0, 1 / diffusion_steps, diffusion_steps) + noise_scheduler.set_timesteps(diffusion_steps, device=self.train_device, mu=mu, sigmas=sigmas) + timesteps = noise_scheduler.timesteps + + # denoising loop + extra_step_kwargs = {} #TODO remove + if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()): + extra_step_kwargs["generator"] = generator + + text_ids = self.model.prepare_text_ids(prompt_embedding) + + + self.model.transformer_to(self.train_device) + for i, timestep in enumerate(tqdm(timesteps, desc="sampling")): + latent_model_input = torch.cat([latent_image]) + expanded_timestep = timestep.expand(latent_model_input.shape[0]) + + guidance = torch.tensor([cfg_scale], device=self.train_device) + + noise_pred = transformer( + hidden_states=latent_model_input.to(dtype=self.model.train_dtype.torch_dtype()), + timestep=expanded_timestep / 1000, + guidance=guidance.to(dtype=self.model.train_dtype.torch_dtype()), + encoder_hidden_states=prompt_embedding.to(dtype=self.model.train_dtype.torch_dtype()), + txt_ids=text_ids, + img_ids=image_ids, + joint_attention_kwargs=None, + return_dict=True + ).sample + + latent_image = noise_scheduler.step(noise_pred, timestep, latent_image, return_dict=False, **extra_step_kwargs)[0] + + on_update_progress(i + 1, len(timesteps)) + + self.model.transformer_to(self.temp_device) + torch_gc() + self.model.vae_to(self.train_device) + + latent_image = self.model.unpack_latents( + latent_image, + height // vae_scale_factor // patch_size, + width // vae_scale_factor // patch_size, + ) + latents = self.model.unscale_latents(latent_image) + latents = self.model.unpatchify_latents(latents) + + image = vae.decode(latents, return_dict=False)[0] + + image = image_processor.postprocess(image, output_type='pil') + + self.model.vae_to(self.temp_device) + torch_gc() + + return 
ModelSamplerOutput( + file_type=FileType.IMAGE, + data=image[0], + ) + + def sample( + self, + sample_config: SampleConfig, + destination: str, + image_format: ImageFormat | None = None, + video_format: VideoFormat | None = None, + audio_format: AudioFormat | None = None, + on_sample: Callable[[ModelSamplerOutput], None] = lambda _: None, + on_update_progress: Callable[[int, int], None] = lambda _, __: None, + ): + sampler_output = self.__sample_base( + prompt=sample_config.prompt, + height=self.quantize_resolution(sample_config.height, 64), + width=self.quantize_resolution(sample_config.width, 64), + seed=sample_config.seed, + random_seed=sample_config.random_seed, + diffusion_steps=sample_config.diffusion_steps, + cfg_scale=sample_config.cfg_scale, + noise_scheduler=sample_config.noise_scheduler, + text_encoder_sequence_length=sample_config.text_encoder_1_sequence_length, + on_update_progress=on_update_progress, + ) + + self.save_sampler_output( + sampler_output, destination, + image_format, video_format, audio_format, + ) + + on_sample(sampler_output) + +factory.register(BaseModelSampler, Flux2Sampler, ModelType.FLUX_DEV_2) diff --git a/modules/modelSampler/FluxSampler.py b/modules/modelSampler/FluxSampler.py index a0e73b593..93f8837ba 100644 --- a/modules/modelSampler/FluxSampler.py +++ b/modules/modelSampler/FluxSampler.py @@ -147,7 +147,6 @@ def __sample_base( self.model.transformer_to(self.temp_device) torch_gc() - latent_image = self.model.unpack_latents( latent_image, height // vae_scale_factor, @@ -160,7 +159,7 @@ def __sample_base( latents = (latent_image / vae.config.scaling_factor) + vae.config.shift_factor image = vae.decode(latents, return_dict=False)[0] - do_denormalize = [True] * image.shape[0] + do_denormalize = [True] * image.shape[0] #TODO remove and test, from Flux and other models. 
True is the default image = image_processor.postprocess(image, output_type='pil', do_denormalize=do_denormalize) self.model.vae_to(self.temp_device) diff --git a/modules/modelSaver/Flux2FineTuneModelSaver.py b/modules/modelSaver/Flux2FineTuneModelSaver.py new file mode 100644 index 000000000..4d86200da --- /dev/null +++ b/modules/modelSaver/Flux2FineTuneModelSaver.py @@ -0,0 +1,11 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSaver.flux2.Flux2ModelSaver import Flux2ModelSaver +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver +from modules.util.enum.ModelType import ModelType + +Flux2FineTuneModelSaver = make_fine_tune_model_saver( + ModelType.FLUX_DEV_2, + model_class=Flux2Model, + model_saver_class=Flux2ModelSaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/Flux2LoRAModelSaver.py b/modules/modelSaver/Flux2LoRAModelSaver.py new file mode 100644 index 000000000..7a3cbaf3c --- /dev/null +++ b/modules/modelSaver/Flux2LoRAModelSaver.py @@ -0,0 +1,11 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSaver.flux2.Flux2LoRASaver import Flux2LoRASaver +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver +from modules.util.enum.ModelType import ModelType + +Flux2LoRAModelSaver = make_lora_model_saver( + ModelType.FLUX_DEV_2, + model_class=Flux2Model, + lora_saver_class=Flux2LoRASaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/flux2/Flux2LoRASaver.py b/modules/modelSaver/flux2/Flux2LoRASaver.py new file mode 100644 index 000000000..15471a82c --- /dev/null +++ b/modules/modelSaver/flux2/Flux2LoRASaver.py @@ -0,0 +1,52 @@ +import os +from pathlib import Path + +from modules.model.Flux2Model import Flux2Model, diffusers_lora_to_comfy +from modules.modelSaver.mixin.LoRASaverMixin import LoRASaverMixin +from modules.util.convert.lora.convert_lora_util import LoraConversionKeySet +from modules.util.convert_util import convert +from modules.util.enum.ModelFormat import ModelFormat + +import torch +from torch import Tensor + +from safetensors.torch import save_file + + +class Flux2LoRASaver( + LoRASaverMixin, +): + def __init__(self): + super().__init__() + + def _get_convert_key_sets(self, model: Flux2Model) -> list[LoraConversionKeySet] | None: + return None + + def _get_state_dict( + self, + model: Flux2Model, + ) -> dict[str, Tensor]: + state_dict = {} + if model.transformer_lora is not None: + state_dict |= model.transformer_lora.state_dict() + if model.lora_state_dict is not None: + state_dict |= model.lora_state_dict + + return state_dict + + def save( + self, + model: Flux2Model, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + if output_model_format == ModelFormat.COMFY_LORA: + state_dict = self._get_state_dict(model) + save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) + save_state_dict = convert(save_state_dict, diffusers_lora_to_comfy) + + os.makedirs(Path(output_model_destination).parent.absolute(), exist_ok=True) + save_file(save_state_dict, output_model_destination, self._create_safetensors_header(model, save_state_dict)) + else: + self._save(model, output_model_format, output_model_destination, dtype) diff --git a/modules/modelSaver/flux2/Flux2ModelSaver.py b/modules/modelSaver/flux2/Flux2ModelSaver.py new file mode 100644 index 000000000..e2976244e --- /dev/null +++ b/modules/modelSaver/flux2/Flux2ModelSaver.py @@ -0,0 +1,85 @@ +import copy +import os.path +from pathlib 
import Path + +from modules.model.Flux2Model import Flux2Model, diffusers_checkpoint_to_original +from modules.modelSaver.mixin.DtypeModelSaverMixin import DtypeModelSaverMixin +from modules.util.convert_util import convert +from modules.util.enum.ModelFormat import ModelFormat + +import torch + +from safetensors.torch import save_file + + +class Flux2ModelSaver( + DtypeModelSaverMixin, +): + def __init__(self): + super().__init__() + + def __save_diffusers( + self, + model: Flux2Model, + destination: str, + dtype: torch.dtype | None, + ): + # Copy the model to cpu by first moving the original model to cpu. This preserves some VRAM. + pipeline = model.create_pipeline() + pipeline.to("cpu") + if dtype is not None: #TODO necessary? + # replace the tokenizers __deepcopy__ before calling deepcopy, to prevent a copy being made. + # the tokenizer tries to reload from the file system otherwise + tokenizer = pipeline.tokenizer + tokenizer.__deepcopy__ = lambda memo: tokenizer + + save_pipeline = copy.deepcopy(pipeline) + save_pipeline.to(device="cpu", dtype=dtype, silence_dtype_warnings=True) + + delattr(tokenizer, '__deepcopy__') + else: + save_pipeline = pipeline + + os.makedirs(Path(destination).absolute(), exist_ok=True) + save_pipeline.save_pretrained(destination) + + if dtype is not None: + del save_pipeline + + def __save_safetensors( + self, + model: Flux2Model, + destination: str, + dtype: torch.dtype | None, + ): + state_dict = model.transformer.state_dict() + state_dict = convert(state_dict, diffusers_checkpoint_to_original) + + save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) + self._convert_state_dict_to_contiguous(save_state_dict) + + os.makedirs(Path(destination).parent.absolute(), exist_ok=True) + + save_file(save_state_dict, destination, self._create_safetensors_header(model, save_state_dict)) + + def __save_internal( + self, + model: Flux2Model, + destination: str, + ): + self.__save_diffusers(model, destination, None) + + def save( + self, + model: Flux2Model, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + match output_model_format: + case ModelFormat.DIFFUSERS: + self.__save_diffusers(model, output_model_destination, dtype) + case ModelFormat.SAFETENSORS: + self.__save_safetensors(model, output_model_destination, dtype) + case ModelFormat.INTERNAL: + self.__save_internal(model, output_model_destination) diff --git a/modules/modelSetup/BaseFlux2Setup.py b/modules/modelSetup/BaseFlux2Setup.py new file mode 100644 index 000000000..ecd64efe7 --- /dev/null +++ b/modules/modelSetup/BaseFlux2Setup.py @@ -0,0 +1,198 @@ +from abc import ABCMeta +from random import Random + +import modules.util.multi_gpu_util as multi +from modules.model.Flux2Model import Flux2Model +from modules.model.FluxModel import FluxModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.mixin.ModelSetupDebugMixin import ModelSetupDebugMixin +from modules.modelSetup.mixin.ModelSetupDiffusionLossMixin import ModelSetupDiffusionLossMixin +from modules.modelSetup.mixin.ModelSetupEmbeddingMixin import ModelSetupEmbeddingMixin +from modules.modelSetup.mixin.ModelSetupFlowMatchingMixin import ModelSetupFlowMatchingMixin +from modules.modelSetup.mixin.ModelSetupNoiseMixin import ModelSetupNoiseMixin +from modules.util.checkpointing_util import ( + enable_checkpointing_for_flux2_transformer, + enable_checkpointing_for_mistral_encoder_layers, +) +from modules.util.config.TrainConfig import TrainConfig +from 
modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc +from modules.util.TrainProgress import TrainProgress + +import torch +from torch import Tensor + + +class BaseFlux2Setup( + BaseModelSetup, + ModelSetupDiffusionLossMixin, + ModelSetupDebugMixin, + ModelSetupNoiseMixin, + ModelSetupFlowMatchingMixin, + ModelSetupEmbeddingMixin, + metaclass=ABCMeta +): + LAYER_PRESETS = { + "blocks": ["transformer_block"], + "full": [], + } + + def setup_optimizations( + self, + model: Flux2Model, + config: TrainConfig, + ): + if config.gradient_checkpointing.enabled(): + model.transformer_offload_conductor = \ + enable_checkpointing_for_flux2_transformer(model.transformer, config) + if model.text_encoder is not None: + model.text_encoder_offload_conductor = \ + enable_checkpointing_for_mistral_encoder_layers(model.text_encoder, config) + + if config.force_circular_padding: + raise NotImplementedError #TODO applies to Flux2? +# apply_circular_padding_to_conv2d(model.vae) +# apply_circular_padding_to_conv2d(model.transformer) +# if model.transformer_lora is not None: +# apply_circular_padding_to_conv2d(model.transformer_lora) + + model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ + config.weight_dtypes().transformer, + config.weight_dtypes().text_encoder, + config.weight_dtypes().vae, + config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, + ], config.enable_autocast_cache) + + model.text_encoder_autocast_context, model.text_encoder_train_dtype = \ + disable_fp16_autocast_context( + self.train_device, + config.train_dtype, + config.fallback_train_dtype, + [ + config.weight_dtypes().text_encoder, + config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, + ], + config.enable_autocast_cache, + ) + + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) + + def predict( + self, + model: Flux2Model, + batch: dict, + config: TrainConfig, + train_progress: TrainProgress, + *, + deterministic: bool = False, + ) -> dict: + with model.autocast_context: + batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank() + generator = torch.Generator(device=config.train_device) + generator.manual_seed(batch_seed) + rand = Random(batch_seed) + + text_encoder_output = model.encode_text( + train_device=self.train_device, + batch_size=batch['latent_image'].shape[0], + rand=rand, + tokens=batch.get("tokens"), + tokens_mask=batch.get("tokens_mask"), + text_encoder_sequence_length=config.text_encoder_sequence_length, + text_encoder_output=batch.get('text_encoder_hidden_state'), + text_encoder_dropout_probability=config.text_encoder.dropout_probability, + ) + latent_image = model.patchify_latents(batch['latent_image'].float()) + latent_height = latent_image.shape[-2] + latent_width = latent_image.shape[-1] + scaled_latent_image = model.scale_latents(latent_image) + + latent_noise = self._create_noise(scaled_latent_image, config, generator) + + shift = model.calculate_timestep_shift(latent_height, latent_width) + timestep = self._get_timestep_discrete( + 
model.noise_scheduler.config['num_train_timesteps'], + deterministic, + generator, + scaled_latent_image.shape[0], + config, + shift = shift if config.dynamic_timestep_shifting else config.timestep_shift, + ) + + scaled_noisy_latent_image, sigma = self._add_noise_discrete( + scaled_latent_image, + latent_noise, + timestep, + model.noise_scheduler.timesteps, + ) + latent_input = scaled_noisy_latent_image + + guidance = torch.tensor([config.transformer.guidance_scale], device=self.train_device) + guidance = guidance.expand(latent_input.shape[0]) + + text_ids = model.prepare_text_ids(text_encoder_output) + image_ids = model.prepare_latent_image_ids(latent_input) + packed_latent_input = model.pack_latents(latent_input) + + packed_predicted_flow = model.transformer( + hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), + timestep=timestep / 1000, + guidance=guidance.to(dtype=model.train_dtype.torch_dtype()), + encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()), + txt_ids=text_ids, + img_ids=image_ids, + joint_attention_kwargs=None, + return_dict=True + ).sample + + predicted_flow = model.unpack_latents( + packed_predicted_flow, + latent_input.shape[2], + latent_input.shape[3], + ) + + flow = latent_noise - scaled_latent_image + model_output_data = { + 'loss_type': 'target', + 'timestep': timestep, + 'predicted': predicted_flow, + 'target': flow, + } + + if config.debug_mode: + with torch.no_grad(): + predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma + self._save_tokens("7-prompt", batch['tokens'], model.tokenizer, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) + + return model_output_data + + def calculate_loss( + self, + model: Flux2Model, + batch: dict, + data: dict, + config: TrainConfig, + ) -> Tensor: + return self._flow_matching_losses( + batch=batch, + data=data, + config=config, + train_device=self.train_device, + sigmas=model.noise_scheduler.sigmas, + ).mean() + + def prepare_text_caching(self, model: FluxModel, config: TrainConfig): + model.to(self.temp_device) + model.text_encoder_to(self.train_device) + model.eval() + torch_gc() diff --git a/modules/modelSetup/Flux2FineTuneSetup.py b/modules/modelSetup/Flux2FineTuneSetup.py new file mode 100644 index 000000000..06b0ff2ad --- /dev/null +++ b/modules/modelSetup/Flux2FineTuneSetup.py @@ -0,0 +1,88 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.ModuleFilter import ModuleFilter +from modules.util.NamedParameterGroup import NamedParameterGroupCollection +from modules.util.optimizer_util import init_model_parameters +from modules.util.TrainProgress import TrainProgress + +import torch + + +class Flux2FineTuneSetup( + BaseFlux2Setup, +): + def __init__( + self, + 
train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super().__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: Flux2Model, + config: TrainConfig, + ) -> NamedParameterGroupCollection: + parameter_group_collection = NamedParameterGroupCollection() + + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, + freeze=ModuleFilter.create(config), debug=config.debug_mode) + return parameter_group_collection + + def __setup_requires_grad( + self, + model: Flux2Model, + config: TrainConfig, + ): + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) + model.vae.requires_grad_(False) + model.text_encoder.requires_grad_(False) + + + def setup_model( + self, + model: Flux2Model, + config: TrainConfig, + ): + self.__setup_requires_grad(model, config) + init_model_parameters(model, self.create_parameters(model, config), self.train_device) + + def setup_train_device( + self, + model: Flux2Model, + config: TrainConfig, + ): + vae_on_train_device = not config.latent_caching + text_encoder_on_train_device = not config.latent_caching + + model.text_encoder_to(self.train_device if text_encoder_on_train_device else self.temp_device) + model.vae_to(self.train_device if vae_on_train_device else self.temp_device) + model.transformer_to(self.train_device) + + model.text_encoder.eval() + model.vae.eval() + + if config.transformer.train: + model.transformer.train() + else: + model.transformer.eval() + + def after_optimizer_step( + self, + model: Flux2Model, + config: TrainConfig, + train_progress: TrainProgress + ): + self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, Flux2FineTuneSetup, ModelType.FLUX_DEV_2, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/Flux2LoRASetup.py b/modules/modelSetup/Flux2LoRASetup.py new file mode 100644 index 000000000..d9f9b428c --- /dev/null +++ b/modules/modelSetup/Flux2LoRASetup.py @@ -0,0 +1,101 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.NamedParameterGroup import NamedParameterGroupCollection +from modules.util.optimizer_util import init_model_parameters +from modules.util.TrainProgress import TrainProgress + +import torch + + +class Flux2LoRASetup( + BaseFlux2Setup, +): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super().__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: Flux2Model, + config: TrainConfig, + ) -> NamedParameterGroupCollection: + parameter_group_collection = NamedParameterGroupCollection() + + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) + return parameter_group_collection + + def __setup_requires_grad( + self, + model: Flux2Model, + config: TrainConfig, + ): + model.text_encoder.requires_grad_(False) + model.transformer.requires_grad_(False) + 
model.vae.requires_grad_(False) + + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) + + def setup_model( + self, + model: Flux2Model, + config: TrainConfig, + ): + model.transformer_lora = LoRAModuleWrapper( + model.transformer, "lora_transformer", config, config.layer_filter.split(",") + ) + + if model.lora_state_dict: + model.transformer_lora.load_state_dict(model.lora_state_dict) + model.lora_state_dict = None + + model.transformer_lora.set_dropout(config.dropout_probability) + model.transformer_lora.to(dtype=config.lora_weight_dtype.torch_dtype()) + model.transformer_lora.hook_to_module() + + self.__setup_requires_grad(model, config) + + init_model_parameters(model, self.create_parameters(model, config), self.train_device) + + def setup_train_device( + self, + model: Flux2Model, + config: TrainConfig, + ): + vae_on_train_device = not config.latent_caching + text_encoder_on_train_device = not config.latent_caching + + model.text_encoder_to(self.train_device if text_encoder_on_train_device else self.temp_device) + model.vae_to(self.train_device if vae_on_train_device else self.temp_device) + model.transformer_to(self.train_device) + + model.text_encoder.eval() + model.vae.eval() + + if config.transformer.train: + model.transformer.train() + else: + model.transformer.eval() + + def after_optimizer_step( + self, + model: Flux2Model, + config: TrainConfig, + train_progress: TrainProgress + ): + self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, Flux2LoRASetup, ModelType.FLUX_DEV_2, TrainingMethod.LORA) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index b73745879..3da6342ca 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -164,8 +164,8 @@ def torch_backward(a, b): run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8") run_benchmark(lambda: mm_8bit(y_8, w_8), "triton mm backward int8") - run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int", compile=True) - run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int", compile=True) + run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale, bias=None, compute_dtype=torch.bfloat16), "torch forward int", compile=True) + run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale, bias=None, compute_dtype=torch.bfloat16), "triton backward int", compile=True) @torch.no_grad() diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 1e336ab2b..7d52d1748 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -5,6 +5,7 @@ from modules.util.enum.ConfigPart import ConfigPart from modules.util.enum.DataType import DataType from modules.util.enum.ModelFormat import ModelFormat +from modules.util.enum.ModelType import PeftType from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ui import components from modules.util.ui.UIState import UIState @@ -55,8 +56,10 @@ def refresh_ui(self): self.__setup_wuerstchen_ui(base_frame) elif self.train_config.model_type.is_pixart(): self.__setup_pixart_alpha_ui(base_frame) - elif self.train_config.model_type.is_flux(): + elif self.train_config.model_type.is_flux_1(): self.__setup_flux_ui(base_frame) + elif self.train_config.model_type.is_flux_2(): + self.__setup_flux_2_ui(base_frame) elif self.train_config.model_type.is_z_image(): self.__setup_z_image_ui(base_frame) elif 
self.train_config.model_type.is_chroma(): @@ -131,6 +134,26 @@ def __setup_flux_ui(self, frame): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) + def __setup_flux_2_ui(self, frame): + row = 0 + row = self.__create_base_dtype_components(frame, row) + row = self.__create_base_components( + frame, + row, + has_transformer=True, + allow_override_transformer=True, + has_text_encoder_1=True, + has_vae=True, + ) + row = self.__create_output_components( + frame, + row, + allow_safetensors=True, + allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, + allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, + allow_comfy=self.train_config.training_method == TrainingMethod.LORA and self.train_config.peft_type == PeftType.LORA, + ) + def __setup_z_image_ui(self, frame): row = 0 row = self.__create_base_dtype_components(frame, row) @@ -590,6 +613,7 @@ def __create_output_components( allow_safetensors: bool = False, allow_diffusers: bool = False, allow_legacy_safetensors: bool = False, + allow_comfy: bool = False, ) -> int: # output model destination components.label(frame, row, 0, "Model Output Destination", @@ -617,6 +641,8 @@ def __create_output_components( formats.append(("Diffusers", ModelFormat.DIFFUSERS)) # if allow_legacy_safetensors: # formats.append(("Legacy Safetensors", ModelFormat.LEGACY_SAFETENSORS)) + if allow_comfy: + formats.append(("Comfy", ModelFormat.COMFY_LORA)) components.label(frame, row, 0, "Output Format", tooltip="Format to use when saving the output model") diff --git a/modules/ui/TopBar.py b/modules/ui/TopBar.py index c53ea4160..00c211114 100644 --- a/modules/ui/TopBar.py +++ b/modules/ui/TopBar.py @@ -92,8 +92,9 @@ def __init__( ("Stable Cascade", ModelType.STABLE_CASCADE_1), ("PixArt Alpha", ModelType.PIXART_ALPHA), ("PixArt Sigma", ModelType.PIXART_SIGMA), - ("Flux Dev", ModelType.FLUX_DEV_1), + ("Flux Dev.1", ModelType.FLUX_DEV_1), ("Flux Fill Dev", ModelType.FLUX_FILL_DEV_1), + ("Flux Dev.2", ModelType.FLUX_DEV_2), ("Sana", ModelType.SANA), ("Hunyuan Video", ModelType.HUNYUAN_VIDEO), ("HiDream Full", ModelType.HI_DREAM_FULL), diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index 931ba8039..2c564e8d3 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -69,8 +69,10 @@ def refresh_ui(self): self.__setup_wuerstchen_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_pixart(): self.__setup_pixart_alpha_ui(column_0, column_1, column_2) - elif self.train_config.model_type.is_flux(): + elif self.train_config.model_type.is_flux_1(): self.__setup_flux_ui(column_0, column_1, column_2) + elif self.train_config.model_type.is_flux_2(): + self.__setup_flux_2_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_chroma(): self.__setup_chroma_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_qwen(): @@ -167,6 +169,18 @@ def __setup_flux_ui(self, column_0, column_1, column_2): self.__create_loss_frame(column_2, 2) self.__create_layer_frame(column_2, 3) + def __setup_flux_2_ui(self, column_0, column_1, column_2): + self.__create_base_frame(column_0, 0) + self.__create_text_encoder_frame(column_0, 1, supports_clip_skip=False, supports_training=False, supports_sequence_length=True) + + self.__create_base2_frame(column_1, 0) + self.__create_transformer_frame(column_1, 1, supports_guidance_scale=True, supports_force_attention_mask=False) + self.__create_noise_frame(column_1, 2, 
supports_dynamic_timestep_shifting=True) + + self.__create_masked_frame(column_2, 1) + self.__create_loss_frame(column_2, 2) + self.__create_layer_frame(column_2, 3) + def __setup_chroma_ui(self, column_0, column_1, column_2): self.__create_base_frame(column_0, 0) self.__create_text_encoder_frame(column_0, 1) @@ -400,12 +414,11 @@ def __create_base2_frame(self, master, row, video_training_enabled: bool = False tooltip="Enables circular padding for all conv layers to better train seamless images") components.switch(frame, row, 1, self.ui_state, "force_circular_padding") - def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supports_training=True): + def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supports_training=True, supports_sequence_length=False): frame = ctk.CTkFrame(master=master, corner_radius=5) frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") frame.grid_columnconfigure(0, weight=1) - # train text encoder if supports_training: components.label(frame, 0, 0, "Train Text Encoder", tooltip="Enables training the text encoder model") @@ -434,6 +447,13 @@ def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supp tooltip="The number of additional clip layers to skip. 0 = the model default") components.entry(frame, 4, 1, self.ui_state, "text_encoder_layer_skip") + if supports_sequence_length: + # text encoder sequence length + components.label(frame, row, 0, "Text Encoder Sequence Length", + tooltip="Number of tokens for captions") + components.entry(frame, row, 1, self.ui_state, "text_encoder_sequence_length") + row += 1 + def __create_text_encoder_n_frame( self, master, diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index 133f97cf0..6ca75d8a3 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -25,6 +25,7 @@ from transformers.models.clip.modeling_clip import CLIPEncoderLayer from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLDecoderLayer from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer from transformers.models.t5.modeling_t5 import T5Block @@ -111,7 +112,6 @@ def __init__(self, orig_module: nn.Module, orig_forward, train_device: torch.dev self.layer_index = layer_index def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): - if self.layer_index == 0 and not torch.is_grad_enabled(): self.conductor.start_forward(True) @@ -131,7 +131,6 @@ def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): def forward(self, *args, **kwargs): call_id = _generate_call_index() args = _kwargs_to_args(self.orig_forward if self.checkpoint is None else self.checkpoint.forward, args, kwargs) - if torch.is_grad_enabled(): return torch.utils.checkpoint.checkpoint( self.__checkpointing_forward, @@ -306,6 +305,16 @@ def enable_checkpointing_for_llama_encoder_layers( (LlamaDecoderLayer, []), ]) +def enable_checkpointing_for_mistral_encoder_layers( + model: nn.Module, + config: TrainConfig, +) -> LayerOffloadConductor: + return enable_checkpointing(model, config, False, [ + (MistralDecoderLayer, []), + ]) + + + def enable_checkpointing_for_qwen_encoder_layers( model: nn.Module, config: TrainConfig, @@ -339,6 +348,15 @@ def 
enable_checkpointing_for_flux_transformer( (model.single_transformer_blocks, ["hidden_states" ]), ]) +def enable_checkpointing_for_flux2_transformer( + model: nn.Module, + config: TrainConfig, +) -> LayerOffloadConductor: + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + (model.single_transformer_blocks, ["hidden_states" ]), + ]) + def enable_checkpointing_for_chroma_transformer( model: nn.Module, diff --git a/modules/util/config/SampleConfig.py b/modules/util/config/SampleConfig.py index 4b72345f8..1b2aba652 100644 --- a/modules/util/config/SampleConfig.py +++ b/modules/util/config/SampleConfig.py @@ -19,6 +19,7 @@ class SampleConfig(BaseConfig): noise_scheduler: NoiseScheduler text_encoder_1_layer_skip: int + text_encoder_1_sequence_length: int | None text_encoder_2_layer_skip: int text_encoder_2_sequence_length: int | None text_encoder_3_layer_skip: int @@ -35,6 +36,7 @@ def __init__(self, data: list[(str, Any, type, bool)]): def from_train_config(self, train_config): self.text_encoder_1_layer_skip = train_config.text_encoder_layer_skip + self.text_encoder_1_sequence_length = train_config.text_encoder_sequence_length self.text_encoder_2_layer_skip = train_config.text_encoder_2_layer_skip self.text_encoder_2_sequence_length = train_config.text_encoder_2_sequence_length self.text_encoder_3_layer_skip = train_config.text_encoder_3_layer_skip @@ -60,6 +62,7 @@ def default_values(): data.append(("noise_scheduler", NoiseScheduler.DDIM, NoiseScheduler, False)) data.append(("text_encoder_1_layer_skip", 0, int, False)) + data.append(("text_encoder_1_sequence_length", None, int, True)) data.append(("text_encoder_2_layer_skip", 0, int, False)) data.append(("text_encoder_2_sequence_length", None, int, True)) data.append(("text_encoder_3_layer_skip", 0, int, False)) diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index ddaee4b89..3a24a10df 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -273,7 +273,7 @@ class TrainModelPartConfig(BaseConfig): stop_training_after_unit: TimeUnit learning_rate: float weight_dtype: DataType - dropout_probability: float + dropout_probability: float #this is text encoder caption dropout! train_embedding: bool attention_mask: bool guidance_scale: float @@ -430,7 +430,7 @@ class TrainConfig(BaseConfig): vb_loss_strength: float loss_weight_fn: LossWeight loss_weight_strength: float - dropout_probability: float + dropout_probability: float #this is LoRA dropout! 
loss_scaler: LossScaler learning_rate_scaler: LearningRateScaler clip_grad_norm: float @@ -872,6 +872,12 @@ def train_text_encoder_4_or_embedding(self) -> bool: or ((self.text_encoder_4.train_embedding or not self.model_type.has_multiple_text_encoders()) and self.train_any_embedding()) + def train_any_text_encoder_or_embedding(self) -> bool: + return (self.train_text_encoder_or_embedding() + or self.train_text_encoder_2_or_embedding() + or self.train_text_encoder_3_or_embedding() + or self.train_text_encoder_4_or_embedding()) + def all_embedding_configs(self): if self.training_method == TrainingMethod.EMBEDDING: return self.additional_embeddings + [self.embedding] @@ -1069,6 +1075,7 @@ def default_values() -> 'TrainConfig': text_encoder.learning_rate = None data.append(("text_encoder", text_encoder, TrainModelPartConfig, False)) data.append(("text_encoder_layer_skip", 0, int, False)) + data.append(("text_encoder_sequence_length", 512, int, True)) # text encoder 2 text_encoder_2 = TrainModelPartConfig.default_values() diff --git a/modules/util/enum/ModelFormat.py b/modules/util/enum/ModelFormat.py index 597ad4442..70193a61b 100644 --- a/modules/util/enum/ModelFormat.py +++ b/modules/util/enum/ModelFormat.py @@ -6,6 +6,7 @@ class ModelFormat(Enum): CKPT = 'CKPT' SAFETENSORS = 'SAFETENSORS' LEGACY_SAFETENSORS = 'LEGACY_SAFETENSORS' + COMFY_LORA = 'COMFY_LORA' INTERNAL = 'INTERNAL' # an internal format that stores all information to resume training @@ -23,6 +24,8 @@ def file_extension(self) -> str: return '.safetensors' case ModelFormat.LEGACY_SAFETENSORS: return '.safetensors' + case ModelFormat.COMFY_LORA: + return '.safetensors' case _: return '' diff --git a/modules/util/enum/ModelType.py b/modules/util/enum/ModelType.py index bb8740e97..8dfebd0ff 100644 --- a/modules/util/enum/ModelType.py +++ b/modules/util/enum/ModelType.py @@ -25,6 +25,7 @@ class ModelType(Enum): FLUX_DEV_1 = 'FLUX_DEV_1' FLUX_FILL_DEV_1 = 'FLUX_FILL_DEV_1' + FLUX_DEV_2 = 'FLUX_DEV_2' SANA = 'SANA' @@ -77,9 +78,17 @@ def is_pixart_sigma(self): return self == ModelType.PIXART_SIGMA def is_flux(self): + return self == ModelType.FLUX_DEV_1 \ + or self == ModelType.FLUX_FILL_DEV_1 \ + or self == ModelType.FLUX_DEV_2 + + def is_flux_1(self): return self == ModelType.FLUX_DEV_1 \ or self == ModelType.FLUX_FILL_DEV_1 + def is_flux_2(self): + return self == ModelType.FLUX_DEV_2 + def is_chroma(self): return self == ModelType.CHROMA_1 @@ -116,7 +125,7 @@ def has_depth_input(self): def has_multiple_text_encoders(self): return self.is_stable_diffusion_3() \ or self.is_stable_diffusion_xl() \ - or self.is_flux() \ + or self.is_flux_1() \ or self.is_hunyuan_video() \ or self.is_hi_dream() \ diff --git a/requirements-global.txt b/requirements-global.txt index 561d7195d..399946e30 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -21,7 +21,8 @@ pytorch-lightning==2.5.1.post0 # diffusion models #Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers (see BaseQwenSetup): --e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers +#-e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers +-e git+https://github.com/dxqb/diffusers.git@flux2_tuples#egg=diffusers gguf==0.17.1 transformers==4.56.2 sentencepiece==0.2.1 # transitive dependency of transformers for tokenizer loading @@ -33,7 +34,7 @@ pooch==1.8.2 open-clip-torch==2.32.0 # data loader --e git+https://github.com/Nerogar/mgds.git@385578f#egg=mgds +-e 
git+https://github.com/dxqb/mgds.git@flux2#egg=mgds # optimizers dadaptation==3.2 # dadaptation optimizers diff --git a/resources/sd_model_spec/flux_dev_2.0-lora.json b/resources/sd_model_spec/flux_dev_2.0-lora.json new file mode 100644 index 000000000..03c0aed03 --- /dev/null +++ b/resources/sd_model_spec/flux_dev_2.0-lora.json @@ -0,0 +1,6 @@ +{ + "modelspec.sai_model_spec": "1.0.0", + "modelspec.architecture": "Flux.2-dev/lora", + "modelspec.implementation": "https://github.com/huggingface/diffusers", + "modelspec.title": "FluxDev 2.0 LoRA" +} diff --git a/resources/sd_model_spec/flux_dev_2.0.json b/resources/sd_model_spec/flux_dev_2.0.json new file mode 100644 index 000000000..c743cc9e8 --- /dev/null +++ b/resources/sd_model_spec/flux_dev_2.0.json @@ -0,0 +1,6 @@ +{ + "modelspec.sai_model_spec": "1.0.0", + "modelspec.architecture": "Flux.2-dev", + "modelspec.implementation": "https://github.com/huggingface/diffusers", + "modelspec.title": "FluxDev 2.0" +}
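
The enum changes split the old is_flux() check into is_flux_1() and is_flux_2(), so ModelTab and TrainingTab can route FLUX_DEV_2 to the new Flux.2 UI while has_multiple_text_encoders() now counts only the Flux.1 variants. A condensed sketch of that dispatch is below; pick_ui is a hypothetical stand-in for the refresh_ui branches, not code from this patch.

from enum import Enum

class ModelType(Enum):
    FLUX_DEV_1 = 'FLUX_DEV_1'
    FLUX_FILL_DEV_1 = 'FLUX_FILL_DEV_1'
    FLUX_DEV_2 = 'FLUX_DEV_2'

    def is_flux(self):
        # umbrella check, kept for code paths shared by every Flux variant
        return self.is_flux_1() or self.is_flux_2()

    def is_flux_1(self):
        return self in (ModelType.FLUX_DEV_1, ModelType.FLUX_FILL_DEV_1)

    def is_flux_2(self):
        return self == ModelType.FLUX_DEV_2

def pick_ui(model_type: ModelType) -> str:
    # stand-in for the refresh_ui dispatch in ModelTab / TrainingTab
    if model_type.is_flux_1():
        return "flux_ui"
    if model_type.is_flux_2():
        return "flux_2_ui"
    return "other"

assert pick_ui(ModelType.FLUX_DEV_2) == "flux_2_ui"
assert ModelType.FLUX_DEV_2.is_flux() and not ModelType.FLUX_DEV_2.is_flux_1()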
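
Flux.2 captions go through a single text encoder with a configurable sequence length: TrainConfig gains text_encoder_sequence_length (default 512, nullable), the Training tab exposes it as "Text Encoder Sequence Length", and SampleConfig.from_train_config copies it into text_encoder_1_sequence_length so sampling uses the same cap as training. A minimal sketch of that plumbing, using plain dataclasses instead of the project's BaseConfig machinery (an assumption made only for brevity):

from dataclasses import dataclass

@dataclass
class TrainConfig:
    text_encoder_layer_skip: int = 0
    # maximum token length used when tokenizing captions for the text encoder
    text_encoder_sequence_length: int | None = 512

@dataclass
class SampleConfig:
    text_encoder_1_layer_skip: int = 0
    text_encoder_1_sequence_length: int | None = None

    def from_train_config(self, train_config: TrainConfig) -> None:
        # keep sampling consistent with the training-time settings
        self.text_encoder_1_layer_skip = train_config.text_encoder_layer_skip
        self.text_encoder_1_sequence_length = train_config.text_encoder_sequence_length

sample = SampleConfig()
sample.from_train_config(TrainConfig())
assert sample.text_encoder_1_sequence_length == 512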
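
The new COMFY_LORA output format is still written as a .safetensors file, and it is only offered in the Flux.2 output options when the training method is LoRA and the PEFT type is LoRA (the allow_comfy flag). A rough sketch of that gating follows; output_formats is a hypothetical helper, and plain strings stand in for the real TrainingMethod and PeftType enums.

from enum import Enum

class ModelFormat(Enum):
    SAFETENSORS = 'SAFETENSORS'
    DIFFUSERS = 'DIFFUSERS'
    COMFY_LORA = 'COMFY_LORA'

    def file_extension(self) -> str:
        # COMFY_LORA is saved as a .safetensors file; the ComfyUI-specific part
        # is assumed to be the key naming, which is not shown in this section
        match self:
            case ModelFormat.SAFETENSORS | ModelFormat.COMFY_LORA:
                return '.safetensors'
            case _:
                return ''

def output_formats(training_method: str, peft_type: str) -> list[ModelFormat]:
    # mirrors the allow_safetensors / allow_diffusers / allow_comfy flags
    formats = [ModelFormat.SAFETENSORS]
    if training_method == 'FINE_TUNE':
        formats.append(ModelFormat.DIFFUSERS)
    if training_method == 'LORA' and peft_type == 'LORA':
        formats.append(ModelFormat.COMFY_LORA)
    return formats

assert ModelFormat.COMFY_LORA.file_extension() == '.safetensors'
assert ModelFormat.COMFY_LORA in output_formats('LORA', 'LORA')
assert ModelFormat.COMFY_LORA not in output_formats('LORA', 'LOHA')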