From 1e8fecedddfd05ea33751d4f56e0eeef75372691 Mon Sep 17 00:00:00 2001
From: dxqb <183307934+dxqb@users.noreply.github.com>
Date: Sat, 4 Oct 2025 17:42:12 +0200
Subject: [PATCH 01/54] Compile int svd

---
 .../modelLoader/mixin/HFModelLoaderMixin.py   |  19 +-
 modules/modelSetup/BaseChromaSetup.py         |   9 +-
 modules/modelSetup/BaseFluxSetup.py           |  10 +-
 modules/modelSetup/BaseHiDreamSetup.py        |  14 +-
 modules/modelSetup/BaseHunyuanVideoSetup.py   |  10 +-
 modules/modelSetup/BasePixArtAlphaSetup.py    |   8 +-
 modules/modelSetup/BaseQwenSetup.py           |   8 +-
 modules/modelSetup/BaseSanaSetup.py           |   8 +-
 .../modelSetup/BaseStableDiffusion3Setup.py   |  12 +-
 .../modelSetup/BaseStableDiffusionSetup.py    |   8 +-
 .../modelSetup/BaseStableDiffusionXLSetup.py  |  10 +-
 modules/modelSetup/BaseWuerstchenSetup.py     |  12 +-
 .../mixin/ModelSetupDiffusionLossMixin.py     |   9 +-
 modules/module/quantized/LinearFp8.py         |   2 +-
 modules/module/quantized/LinearNf4.py         |   2 +-
 modules/module/quantized/LinearSVD.py         |  79 +++
 modules/module/quantized/LinearW8A8.py        | 223 +++++++
 modules/trainer/GenericTrainer.py             |  17 +-
 modules/ui/ModelTab.py                        |  39 +-
 modules/ui/TrainUI.py                         |   1 +
 modules/util/checkpointing_util.py            | 555 ++++++------
 modules/util/config/TrainConfig.py            |   6 +
 modules/util/enum/DataType.py                 |  30 +-
 modules/util/memory_util.py                   |  24 -
 modules/util/profiling_util.py                |  53 ++
 modules/util/quantization_util.py             | 117 ++--
 modules/util/triton_mm_8bit.py                | 121 ++++
 requirements-cuda.txt                         |   6 +-
 requirements-global.txt                       |   2 +-
 requirements-rocm.txt                         |   3 +
 30 files changed, 857 insertions(+), 560 deletions(-)
 create mode 100644 modules/module/quantized/LinearSVD.py
 create mode 100644 modules/module/quantized/LinearW8A8.py
 delete mode 100644 modules/util/memory_util.py
 create mode 100644 modules/util/profiling_util.py
 create mode 100644 modules/util/triton_mm_8bit.py

diff --git a/modules/modelLoader/mixin/HFModelLoaderMixin.py b/modules/modelLoader/mixin/HFModelLoaderMixin.py
index 163702417..6570fc3e9 100644
--- a/modules/modelLoader/mixin/HFModelLoaderMixin.py
+++ b/modules/modelLoader/mixin/HFModelLoaderMixin.py
@@ -8,9 +8,7 @@
 from modules.util.enum.DataType import DataType
 from modules.util.quantization_util import (
     is_quantized_parameter,
-    replace_linear_with_fp8_layers,
-    replace_linear_with_int8_layers,
-    replace_linear_with_nf4_layers,
+    replace_linear_with_quantized_layers,
 )

 import torch
@@ -42,12 +40,7 @@ def __load_sub_module(
             keep_in_fp32_modules = []

         with accelerate.init_empty_weights():
-            if dtype.quantize_nf4():
-                replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, copy_parameters=False)
-            elif dtype.quantize_int8():
-                replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, copy_parameters=False)
-            elif dtype.quantize_fp8():
-                replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, copy_parameters=False)
+            replace_linear_with_quantized_layers(sub_module, dtype, keep_in_fp32_modules, copy_parameters=False)

         is_local = os.path.isdir(pretrained_model_name_or_path)

@@ -247,17 +240,11 @@ def __convert_sub_module_to_dtype(
         if keep_in_fp32_modules is None:
             keep_in_fp32_modules = []

-        if dtype.quantize_nf4():
-            replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, copy_parameters=True)
-        elif dtype.quantize_int8():
-            replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, copy_parameters=True)
-        elif dtype.quantize_fp8():
-            replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, copy_parameters=True)
+        replace_linear_with_quantized_layers(sub_module, dtype, keep_in_fp32_modules,
copy_parameters=True) for module_name, module in sub_module.named_modules(): param_iter = [(x, y[0], y[1]) for x, y in zip(repeat(False), module._parameters.items(), strict=False)] buffer_iter = [(x, y[0], y[1]) for x, y in zip(repeat(True), module._buffers.items(), strict=False)] - for is_buffer, tensor_name, value in param_iter + buffer_iter: if value is not None and torch.is_floating_point(value): old_type = type(value) diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 6a9342b41..96544e122 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -82,9 +82,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -225,7 +225,6 @@ def predict( ) packed_latent_input = model.pack_latents(latent_input) - image_seq_len = packed_latent_input.shape[1] image_attention_mask = torch.full((packed_latent_input.shape[0], image_seq_len), True, dtype=torch.bool, device=text_attention_mask.device) attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) @@ -329,5 +328,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - sigmas=model.noise_scheduler.sigmas.to(device=self.train_device), + sigmas=model.noise_scheduler.sigmas, ).mean() diff --git a/modules/modelSetup/BaseFluxSetup.py b/modules/modelSetup/BaseFluxSetup.py index b53f638f6..307894ec9 100644 --- a/modules/modelSetup/BaseFluxSetup.py +++ b/modules/modelSetup/BaseFluxSetup.py @@ -84,10 +84,10 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.text_encoder_2_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.text_encoder_2_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -387,5 +387,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - sigmas=model.noise_scheduler.sigmas.to(device=self.train_device), + sigmas=model.noise_scheduler.sigmas, ).mean() diff --git a/modules/modelSetup/BaseHiDreamSetup.py b/modules/modelSetup/BaseHiDreamSetup.py index a9a6da8af..91b71a2f3 100644 --- a/modules/modelSetup/BaseHiDreamSetup.py +++ b/modules/modelSetup/BaseHiDreamSetup.py @@ -97,12 +97,12 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype) - 
quantize_layers(model.text_encoder_4, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype, config) + quantize_layers(model.text_encoder_4, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config) def _setup_embeddings( self, @@ -474,5 +474,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - sigmas=model.noise_scheduler.sigmas.to(device=self.train_device), + sigmas=model.noise_scheduler.sigmas, ).mean() diff --git a/modules/modelSetup/BaseHunyuanVideoSetup.py b/modules/modelSetup/BaseHunyuanVideoSetup.py index 786b89380..09c63f2da 100644 --- a/modules/modelSetup/BaseHunyuanVideoSetup.py +++ b/modules/modelSetup/BaseHunyuanVideoSetup.py @@ -84,10 +84,10 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config) model.vae.enable_tiling() @@ -356,5 +356,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - sigmas=model.noise_scheduler.sigmas.to(device=self.train_device), + sigmas=model.noise_scheduler.sigmas , ).mean() diff --git a/modules/modelSetup/BasePixArtAlphaSetup.py b/modules/modelSetup/BasePixArtAlphaSetup.py index 3e8ad0a38..803e75831 100644 --- a/modules/modelSetup/BasePixArtAlphaSetup.py +++ b/modules/modelSetup/BasePixArtAlphaSetup.py @@ -81,9 +81,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -339,5 +339,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - betas=model.noise_scheduler.betas.to(device=self.train_device), + betas=model.noise_scheduler.betas, ).mean() diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index c8dcd4913..f17fa9c86 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -77,9 +77,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - 
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def predict( self, @@ -250,5 +250,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - sigmas=model.noise_scheduler.sigmas.to(device=self.train_device), + sigmas=model.noise_scheduler.sigmas, ).mean() diff --git a/modules/modelSetup/BaseSanaSetup.py b/modules/modelSetup/BaseSanaSetup.py index 29b810774..18b06c14a 100644 --- a/modules/modelSetup/BaseSanaSetup.py +++ b/modules/modelSetup/BaseSanaSetup.py @@ -93,9 +93,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -312,5 +312,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - betas=model.noise_scheduler.betas.to(device=self.train_device), + betas=model.noise_scheduler.betas, ).mean() diff --git a/modules/modelSetup/BaseStableDiffusion3Setup.py b/modules/modelSetup/BaseStableDiffusion3Setup.py index 4f2888554..13738ca32 100644 --- a/modules/modelSetup/BaseStableDiffusion3Setup.py +++ b/modules/modelSetup/BaseStableDiffusion3Setup.py @@ -85,11 +85,11 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -409,5 +409,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - sigmas=model.noise_scheduler.sigmas.to(device=self.train_device), + sigmas=model.noise_scheduler.sigmas, ).mean() diff --git a/modules/modelSetup/BaseStableDiffusionSetup.py b/modules/modelSetup/BaseStableDiffusionSetup.py index 8878b0b82..0fc6ed0df 100644 --- a/modules/modelSetup/BaseStableDiffusionSetup.py +++ b/modules/modelSetup/BaseStableDiffusionSetup.py @@ -68,9 +68,9 @@ def setup_optimizations( config.weight_dtypes().embedding if config.train_any_embedding() else None, ], config.enable_autocast_cache) - 
quantize_layers(model.text_encoder, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.unet, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.unet, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -333,5 +333,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - betas=model.noise_scheduler.betas.to(device=self.train_device), + betas=model.noise_scheduler.betas, ).mean() diff --git a/modules/modelSetup/BaseStableDiffusionXLSetup.py b/modules/modelSetup/BaseStableDiffusionXLSetup.py index 3d7ee0c92..37121951a 100644 --- a/modules/modelSetup/BaseStableDiffusionXLSetup.py +++ b/modules/modelSetup/BaseStableDiffusionXLSetup.py @@ -76,10 +76,10 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.vae_train_dtype) - quantize_layers(model.unet, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.vae_train_dtype, config) + quantize_layers(model.unet, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -381,5 +381,5 @@ def calculate_loss( data=data, config=config, train_device=self.train_device, - betas=model.noise_scheduler.betas.to(device=self.train_device), + betas=model.noise_scheduler.betas, ).mean() diff --git a/modules/modelSetup/BaseWuerstchenSetup.py b/modules/modelSetup/BaseWuerstchenSetup.py index 973919de4..23b3440a5 100644 --- a/modules/modelSetup/BaseWuerstchenSetup.py +++ b/modules/modelSetup/BaseWuerstchenSetup.py @@ -104,12 +104,12 @@ def setup_optimizations( ) if model.model_type.is_wuerstchen_v2(): - quantize_layers(model.decoder_text_encoder, self.train_device, model.train_dtype) - quantize_layers(model.decoder_decoder, self.train_device, model.train_dtype) - quantize_layers(model.decoder_vqgan, self.train_device, model.train_dtype) - quantize_layers(model.effnet_encoder, self.train_device, model.effnet_encoder_train_dtype) - quantize_layers(model.prior_text_encoder, self.train_device, model.train_dtype) - quantize_layers(model.prior_prior, self.train_device, model.prior_train_dtype) + quantize_layers(model.decoder_text_encoder, self.train_device, model.train_dtype, config) + quantize_layers(model.decoder_decoder, self.train_device, model.train_dtype, config) + quantize_layers(model.decoder_vqgan, self.train_device, model.train_dtype, config) + quantize_layers(model.effnet_encoder, self.train_device, model.effnet_encoder_train_dtype, config) + quantize_layers(model.prior_text_encoder, self.train_device, model.train_dtype, config) + quantize_layers(model.prior_prior, self.train_device, model.prior_train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py b/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py index 5c57275ac..f2ba67911 100644 --- a/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py +++ b/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py 
@@ -240,7 +240,7 @@ def _diffusion_losses( ) -> Tensor: loss_weight = batch['loss_weight'] if self.__coefficients is None and betas is not None: - self.__coefficients = DiffusionScheduleCoefficients.from_betas(betas) + self.__coefficients = DiffusionScheduleCoefficients.from_betas(betas.to(train_device)) self.__alphas_cumprod_fun = alphas_cumprod_fun @@ -255,7 +255,7 @@ def _diffusion_losses( # Scale Losses by Batch and/or GA (if enabled) losses = losses * config.loss_scaler.get_scale(batch_size=config.batch_size, accumulation_steps=config.gradient_accumulation_steps) - losses *= loss_weight.to(device=losses.device, dtype=losses.dtype) + losses *= loss_weight # Apply timestep based loss weighting. if 'timestep' in data: @@ -281,7 +281,7 @@ def _flow_matching_losses( loss_weight = batch['loss_weight'] if self.__sigmas is None and sigmas is not None: num_timesteps = sigmas.shape[0] - all_timesteps = torch.arange(start=1, end=num_timesteps + 1, step=1, dtype=torch.int32, device=sigmas.device) + all_timesteps = torch.arange(start=1, end=num_timesteps + 1, step=1, dtype=torch.int32, device=train_device) self.__sigmas = all_timesteps / num_timesteps if data['loss_type'] == 'target': @@ -294,8 +294,7 @@ def _flow_matching_losses( # Scale Losses by Batch and/or GA (if enabled) losses = losses * config.loss_scaler.get_scale(config.batch_size, config.gradient_accumulation_steps) - - losses *= loss_weight.to(device=losses.device, dtype=losses.dtype) + losses *= loss_weight # Apply timestep based loss weighting. if 'timestep' in data: diff --git a/modules/module/quantized/LinearFp8.py b/modules/module/quantized/LinearFp8.py index f0ad404d5..9c693b629 100644 --- a/modules/module/quantized/LinearFp8.py +++ b/modules/module/quantized/LinearFp8.py @@ -31,7 +31,7 @@ def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch. 
else: return self.weight.detach().to(dtype=dtype) - def quantize(self, device: torch.device | None = None): + def quantize(self, device: torch.device | None = None, **kwargs): if self.is_quantized: return self.is_quantized = True diff --git a/modules/module/quantized/LinearNf4.py b/modules/module/quantized/LinearNf4.py index 2a4bfbf17..2cbc46bc2 100644 --- a/modules/module/quantized/LinearNf4.py +++ b/modules/module/quantized/LinearNf4.py @@ -62,7 +62,7 @@ def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch else: return self.weight.detach().to(dtype=dtype) - def quantize(self, device: torch.device | None = None): + def quantize(self, device: torch.device | None = None, **kwargs): if self.is_quantized: return self.is_quantized = True diff --git a/modules/module/quantized/LinearSVD.py b/modules/module/quantized/LinearSVD.py new file mode 100644 index 000000000..5e612e8b5 --- /dev/null +++ b/modules/module/quantized/LinearSVD.py @@ -0,0 +1,79 @@ +import hashlib +from contextlib import suppress + +from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin +from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin + +import torch + + +class BaseLinearSVD( + QuantizedModuleMixin, + QuantizedLinearMixin, +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +def _get_tensor_hash(t: torch.Tensor) -> str: + tensor = t.detach().cpu().contiguous() + tensor_bytes = tensor.numpy().tobytes() + hash_obj = hashlib.sha256(tensor_bytes) + return hash_obj.hexdigest() + +def make_svd_linear(linear_class): + class LinearSVD( + linear_class, + BaseLinearSVD, + ): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_buffer("svd_up", None) + self.register_buffer("svd_down", None) + + def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + if self.svd_up is None: + return super().unquantized_weight(dtype, device) + else: + return (self.svd_up @ self.svd_down).to(dtype) + super().unquantized_weight(dtype, device) + + def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | None = None, cache_dir: str | None = None, max_cache_rank: int = 128): + if self.svd_up is not None: + return + + W = super().unquantized_weight(torch.float32, device) + orig_device = W.device + if device is not None: + W = W.to(device=device) + + U = None + if cache_dir is not None: + filename = cache_dir + "/" + _get_tensor_hash(W) + ".pt" + with suppress(FileNotFoundError): + U, S, Vh = torch.load(filename, map_location=device) + + if U is None: + #use full svd - torch.svd_lowrank is not reducing the quant range nearly as much: + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + + if cache_dir is not None: + torch.save(( + U[:, :max_cache_rank].clone(), + S[:max_cache_rank].clone(), + Vh[:max_cache_rank, :].clone(), + ), filename) + + U_r = U[:, :rank] + S_r = S[:rank] + Vh_r = Vh[:rank, :] + + self.svd_down = Vh_r.to(svd_dtype) + self.svd_up = (U_r * S_r.unsqueeze(0)).to(svd_dtype) + + self.weight.data = (W - (self.svd_up @ self.svd_down)).to(dtype=self.weight.dtype, device=orig_device) + super().quantize(device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ((x @ self.svd_down.T) @ self.svd_up.T).to(x.dtype) + super().forward(x) + + return LinearSVD diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py new file mode 100644 index 000000000..623e391e8 --- /dev/null +++ 
b/modules/module/quantized/LinearW8A8.py @@ -0,0 +1,223 @@ +from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin +from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin +from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit + +import torch +from torch import nn + + +def quantize_int8_tensorwise(x): + abs_max = x.abs().max() + scale = (abs_max.float() / 127.0).clamp(min=1e-12) + q = x.float().mul_(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) + return q, scale + + +def quantize_int8_channelwise(x, dim=-1): + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 127.0).clamp(min=1e-12) + q = x.float().mul_(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) + return q, scale + + +def quantize_fp8_tensorwise(x): + abs_max = x.abs().max() + scale = (abs_max.float() / 448.0).clamp(min=1e-12) + q = x.float().mul_(1.0 / scale).round().clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + return q, scale + + +def quantize_fp8_channelwise(x, dim=-1): + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 448.0).clamp(min=1e-12) + q = x.float().mul_(1.0 / scale).round_().clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) + return q, scale + + +def unquantize(q, scale, compute_dtype): + return q.to(compute_dtype).mul_(scale) + +def int8_forward_channelwise(x, weight, weight_scale, bias=None): + x_8, x_scale = quantize_int8_channelwise(x) + res = torch._int_mm(x_8, weight.T) + res_scaled = res.to(x.dtype).mul_(weight_scale * x_scale) + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + + +def fp8_forward_channelwise(x, weight, weight_scale, bias=None): + x_8, x_scale = quantize_fp8_channelwise(x) + one = torch.ones(1, device=x.device) + res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) + res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + + +def apply_scale(mm_res, weight_scale, x_scale, compute_dtype): + return mm_res.to(compute_dtype).mul_(weight_scale * x_scale) + +def int8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale): + x_8, x_scale = quantize_int8_channelwise(x) + mm_res = triton_mm_8bit(x_8, weight) + return apply_scale(mm_res, weight_scale, x_scale, compute_dtype=x.dtype) + +def fp8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale): + x_8, x_scale = quantize_fp8_channelwise(x) + mm_res = triton_mm_8bit(x_8, weight) + return apply_scale(mm_res, weight_scale, x_scale, compute_dtype=x.dtype) + + +class LinearInt8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, weight_scale, bias): + ctx.save_for_backward(weight, weight_scale) + return int8_forward_channelwise(x, weight, weight_scale, bias) + + @staticmethod + def backward(ctx, x): + if ctx.needs_input_grad != (True, False, False, False): + raise NotImplementedError("Int A8W8 cannot be used for full finetuning") + + weight, weight_scale = ctx.saved_tensors + return int8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale), None, None, None + +class LinearFp8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, weight_scale, bias): + ctx.save_for_backward(weight, weight_scale) + return fp8_forward_channelwise(x.bfloat16(), weight, weight_scale, bias).bfloat16() + + @staticmethod + def backward(ctx, x): + if ctx.needs_input_grad != (True, False, False, False): + 
raise NotImplementedError("Float W8A8 cannot be used for full finetuning") + + weight, weight_scale = ctx.saved_tensors + return fp8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale), None, None, None + +class LinearW8A8( + nn.Linear, + QuantizedModuleMixin, + QuantizedLinearMixin, +): + is_quantized: bool + + def __init__(self, dtype, compute_dtype, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_quantized = False + + assert dtype in [torch.int8, torch.float8_e4m3fn] + self._dtype = dtype + self._compute_dtype = compute_dtype + + self._scale = torch.tensor(1.0, dtype=torch.float32) + self.register_buffer("scale", self._scale) + + + def original_weight_shape(self) -> tuple[int, ...]: + return self.weight.shape + + def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + if self._scale is not None: + return unquantize(self.weight.detach(), self._scale, self._compute_dtype).to(dtype) + else: + return self.weight.detach().to(dtype) + + def quantize(self, device: torch.device | None = None, **kwargs): + if self.is_quantized: + return + self.is_quantized = True + + self.weight.requires_grad_(False) + weight = self.weight.data + orig_device = weight.device + if weight.dtype != self._dtype: + if device is not None: + weight = weight.to(device=device) + + if self._dtype == torch.int8: + weight, self._scale = quantize_int8_tensorwise(weight) + else: + weight, self._scale = quantize_fp8_tensorwise(weight) + + if device is not None: + weight = weight.to(device=orig_device) + self.weight.data = weight + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) + + if x.shape[0] > 16: + if self._dtype == torch.int8: + y = LinearInt8Function.apply(x, self.weight, self._scale, self.bias) + else: + y = LinearFp8Function.apply(x, self.weight, self._scale, self.bias) + else: + w = unquantize(self.weight, self._scale, compute_dtype=self._compute_dtype) + y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) + + assert y.dtype == self._compute_dtype + return y.reshape(x_orig.shape[:-1] + (self.weight.shape[0], )) + + + + +def run_benchmark(fn, desc, steps=10000, warmup=500): + from tqdm import tqdm + for _ in range(warmup): + fn() + torch.cuda.synchronize() + for _ in tqdm(range(steps), desc=desc): + fn() + torch.cuda.synchronize() + + +@torch.no_grad() +def benchmark_int8(m, k, n, device = "cuda"): + device = "cuda" + + x = torch.randn(m,k, device=device, dtype=torch.bfloat16) + x_8 = torch.ones (m,k, device=device, dtype=torch.int8) + y = torch.randn(m,n, device=device, dtype=torch.bfloat16) + y_8 = torch.ones (m,n, device=device, dtype=torch.int8) + w_8 = torch.ones (n,k, device=device, dtype=torch.int8) + w_scale = torch.ones(1, device=device) + + + run_benchmark(lambda: torch._int_mm(x_8, w_8.T), "torch mm int") + run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), "triton mm int") + def torch_backward(a, b): + torch._int_mm(a, b.T.contiguous().T) + run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8") + run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward int8") + + run_benchmark(lambda: int8_forward_channelwise(x, w_8, w_scale), "torch forward int") + run_benchmark(lambda: int8_backward_W_tensorwise_A_channelwise(y, w_8, w_scale), "triton backward int") + + +@torch.no_grad() +def benchmark_fp8(m, k, n, device = "cuda"): + x = torch.randn(m,k, device=device, dtype=torch.bfloat16) + x_8 = torch.ones (m,k, device=device, 
dtype=torch.float8_e4m3fn) + y = torch.randn(m,n, device=device, dtype=torch.bfloat16) + y_8 = torch.ones (m,n, device=device, dtype=torch.float8_e4m3fn) + w_8 = torch.ones (n,k, device=device, dtype=torch.float8_e4m3fn) + w_scale = torch.ones(1, device=device, dtype=torch.bfloat16) + one_scale = torch.ones(1, device=device) + + run_benchmark(lambda: torch._scaled_mm(x_8, w_8.T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()), "torch mm fp8") + run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), "triton mm fp8") + def torch_backward(a, b): + torch._scaled_mm(a, b.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()) + run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward fp8") + run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward fp8") + run_benchmark(lambda: fp8_forward_channelwise(x, w_8, w_scale), "torch forward fp8") + run_benchmark(lambda: fp8_backward_W_tensorwise_A_channelwise(y, w_8, w_scale), "triton backward fp8") + + +if __name__ == "__main__": + benchmark_int8(2 * 1024 + 50, 3072, 3072 + 16) + benchmark_fp8(2 * 1024 + 50, 3072, 3072 + 16) diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 7dfa76398..cd847a3b3 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -28,7 +28,7 @@ from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.TimeUnit import TimeUnit from modules.util.enum.TrainingMethod import TrainingMethod -from modules.util.memory_util import TorchMemoryRecorder +from modules.util.profiling_util import TorchMemoryRecorder, TorchProfiler from modules.util.time_util import get_string_timestamp from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress @@ -620,10 +620,11 @@ def train(self): has_gradient = False lr_scheduler = None - accumulated_loss = 0.0 + accumulated_loss = torch.tensor(0.0, device=train_device) ema_loss = None ema_loss_steps = 0 epochs = range(train_progress.epoch, self.config.epochs, 1) + for _epoch in tqdm(epochs, desc="epoch") if multi.is_master() else epochs: self.callbacks.on_update_status("Starting epoch/caching") @@ -712,7 +713,7 @@ def sample_commands_fun(): self.callbacks.on_update_status("Training ...") - with TorchMemoryRecorder(enabled=False): + with TorchMemoryRecorder(enabled=False), TorchProfiler(enabled=False, filename=f"step{train_progress.global_step}.json"): step_seed = train_progress.global_step bf16_stochastic_rounding_set_seed(step_seed, train_device) @@ -744,7 +745,7 @@ def sample_commands_fun(): has_gradient = True detached_loss = loss.detach() multi.reduce_tensor_mean(detached_loss) - accumulated_loss += detached_loss.item() + accumulated_loss += detached_loss if self.__is_update_step(train_progress): if self.config.fused_gradient_reduce: @@ -775,13 +776,13 @@ def sample_commands_fun(): self.model, self.config, lr_scheduler, self.tensorboard ) - self.tensorboard.add_scalar("loss/train_step", accumulated_loss, train_progress.global_step) - ema_loss = ema_loss or accumulated_loss + self.tensorboard.add_scalar("loss/train_step", accumulated_loss.item(), train_progress.global_step) + ema_loss = ema_loss or accumulated_loss.item() ema_loss_steps += 1 ema_loss_decay = min(0.99, 1 - (1 / ema_loss_steps)) - ema_loss = (ema_loss * ema_loss_decay) + (accumulated_loss * (1 - ema_loss_decay)) + ema_loss = (ema_loss * ema_loss_decay) + (accumulated_loss.item() * (1 - ema_loss_decay)) step_tqdm.set_postfix({ - 
'loss': accumulated_loss, + 'loss': accumulated_loss.item(), 'smooth loss': ema_loss, }) self.tensorboard.add_scalar("smooth_loss/train_step", ema_loss, train_progress.global_step) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index d27118bb3..c938382bb 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -257,16 +257,26 @@ def __setup_hi_dream_ui(self): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __create_dtype_options(self, include_none:bool=True) -> list[tuple[str, DataType]]: + def __create_dtype_options(self, include_none: bool=True, include_svd: bool=False) -> list[tuple[str, DataType]]: options = [ ("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16), ("float16", DataType.FLOAT_16), - ("float8", DataType.FLOAT_8), + ("float8 (W8)", DataType.FLOAT_8), + ("float W8A8", DataType.FLOAT_W8A8), + ("int W8A8", DataType.INT_W8A8), # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 ("nfloat4", DataType.NFLOAT_4), ] + if include_svd: + options += [ + ("float8 (W8) SVDQuant", DataType.FLOAT_8_SVD), + ("float W8A8 SVDQuant", DataType.FLOAT_W8A8_SVD), + ("int W8A8 SVDQuant", DataType.INT_W8A8_SVD), + ("nfloat4 SVDQuant", DataType.NFLOAT_4_SVD), + ] + if include_none: options.insert(0, ("", DataType.NONE)) @@ -280,8 +290,6 @@ def __create_base_dtype_components(self, row: int) -> int: wide_tooltip=True) components.entry(self.scroll_frame, row, 1, self.ui_state, "secrets.huggingface_token") - row += 1 - # base model components.label(self.scroll_frame, row, 0, "Base Model", tooltip="Filename, directory or Hugging Face repository of the base model") @@ -293,6 +301,7 @@ def __create_base_dtype_components(self, row: int) -> int: # weight dtype components.label(self.scroll_frame, row, 3, "Weight Data Type", tooltip="The base model weight data type used for training. This can reduce memory consumption, but reduces precision") + components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(False), self.ui_state, "weight_dtype") @@ -336,11 +345,31 @@ def __create_base_components( # prior weight dtype components.label(self.scroll_frame, row, 3, "Override Prior Data Type", tooltip="Overrides the prior weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(include_svd=True), self.ui_state, "prior.weight_dtype") row += 1 + # compile + components.label(self.scroll_frame, row, 3, "Compile transformer blocks", + tooltip="Uses torch.compile and Triton to significantly speed up training. Only applies to transformer/unet. 
Disable in case of compatibility issues.") + components.switch(self.scroll_frame, row, 4, self.ui_state, "compile") + + row += 1 + + # SVDQuant + components.label(self.scroll_frame, row, 3, "SVDQuant Data Type", + tooltip="What datatype to use for SVDQuant weights decomposition.") + components.options_kv(self.scroll_frame, row, 4, [("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16)], + self.ui_state, "svd_dtype") + + row += 1 + components.label(self.scroll_frame, row, 3, "SVDQuant Rank", + tooltip="Rank for SVDQuant weights decomposition") + components.entry(self.scroll_frame, row, 4, self.ui_state, "svd_rank") + + row += 1 + if has_text_encoder: # text encoder weight dtype components.label(self.scroll_frame, row, 3, "Override Text Encoder Data Type", diff --git a/modules/ui/TrainUI.py b/modules/ui/TrainUI.py index d3da16162..1b6a49d10 100644 --- a/modules/ui/TrainUI.py +++ b/modules/ui/TrainUI.py @@ -754,6 +754,7 @@ def __training_thread_function(self): trainer.start() if self.train_config.cloud.enabled: self.ui_state.get_var("secrets.cloud").update(self.train_config.secrets.cloud) + self.start_time = time.monotonic() trainer.train() except Exception: diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index c015e9060..ac5e17a40 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -10,29 +10,9 @@ from torch import nn from torch.utils.checkpoint import checkpoint -from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock -from diffusers.models.transformers.sana_transformer import SanaTransformerBlock -from diffusers.models.transformers.transformer_chroma import ChromaSingleTransformerBlock, ChromaTransformerBlock -from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock -from diffusers.models.transformers.transformer_hidream_image import ( - HiDreamImageSingleTransformerBlock, - HiDreamImageTransformerBlock, -) -from diffusers.models.transformers.transformer_hunyuan_video import ( - HunyuanVideoIndividualTokenRefinerBlock, - HunyuanVideoSingleTransformerBlock, - HunyuanVideoTransformerBlock, -) -from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformerBlock -from diffusers.models.unets.unet_stable_cascade import SDCascadeAttnBlock, SDCascadeResBlock, SDCascadeTimestepBlock -from transformers.models.clip.modeling_clip import CLIPEncoderLayer -from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer -from transformers.models.llama.modeling_llama import LlamaDecoderLayer -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLDecoderLayer -from transformers.models.t5.modeling_t5 import T5Block - - -def __kwargs_to_args(fun: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, ...]: +torch._dynamo.config.cache_size_limit = 8192 + +def _kwargs_to_args(fun: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, ...]: signature = dict(inspect.signature(fun).parameters) parameters = [] @@ -61,430 +41,277 @@ def __get_args_indices(fun: Callable, arg_names: list[str]) -> list[int]: __current_call_index = 0 -def __generate_call_index() -> int: +def _generate_call_index() -> int: global __current_call_index __current_call_index += 1 return __current_call_index -def create_checkpointed_forward( +class CheckpointLayer(torch.nn.Module): + def __init__(self, orig: nn.Module, train_device: torch.device): + super().__init__() + self.orig = orig + # 
dummy tensor that requires grad is needed for checkpointing to work when training a LoRA + self.dummy = torch.zeros((1,), device=train_device, requires_grad=True) + #self.orig.compile(fullgraph=True) + + def __checkpointing_forward(self, dummy: torch.Tensor, *args, **kwargs): + return self.orig(*args, **kwargs) + + def forward(self, *args, **kwargs): + if torch.is_grad_enabled(): + return checkpoint( + self.__checkpointing_forward, + self.dummy, + *args, + **kwargs, + use_reentrant=False + ) + else: + return self.orig(*args, **kwargs) + + +class OffloadCheckpointLayer(torch.nn.Module): + def __init__(self, orig: nn.Module, train_device: torch.device, conductor: LayerOffloadConductor, layer_index: int): + super().__init__() + self.orig = orig + self.dummy = torch.zeros((1,), device=train_device, requires_grad=True) + self.conductor = conductor + self.layer_index = layer_index + + def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): + + if self.layer_index == 0 and not torch.is_grad_enabled(): + self.conductor.start_forward(True) + + args = self.conductor.before_layer(self.layer_index, call_id, args) + output = self.orig(*args) + self.conductor.after_layer(self.layer_index, call_id, args) + + # make sure at least one of the output tensors has a grad_fn so the output of the checkpoint has a grad_fn + assert not (torch.is_grad_enabled() and not has_grad_fn(output)) + #TODO how can this be the case? Is there a backward that does not produce gradients wrt to any of its inputs? + #if it be the case, TODO check that add_dummy_grad_fn_ still works with torch.compile + if torch.is_grad_enabled() and not has_grad_fn(output): + output = add_dummy_grad_fn_(output) + + return output + + def forward(self, *args, **kwargs): + call_id = _generate_call_index() + args = _kwargs_to_args(self.orig.forward, args, kwargs) + + if torch.is_grad_enabled(): + return checkpoint( + self.__checkpointing_forward, + self.dummy, + call_id, + *args, + use_reentrant=True + ) + else: + if self.layer_index == 0: + self.conductor.start_forward(False) + + args = self.conductor.before_layer(self.layer_index, call_id, args) + output = self.orig(*args) + self.conductor.after_layer(self.layer_index, call_id, args) + return output + + +def create_checkpoint( orig_module: nn.Module, train_device: torch.device, include_from_offload_param_names: list[str] = None, conductor: LayerOffloadConductor | None = None, layer_index: int = 0, + compile: bool = False, + enabled: bool = True, ) -> Callable: - orig_forward = orig_module.forward if include_from_offload_param_names is None: include_from_offload_param_names = [] - included_offload_param_indices = __get_args_indices(orig_forward, include_from_offload_param_names) + included_offload_param_indices = __get_args_indices(orig_module.forward, include_from_offload_param_names) - bound_conductor = conductor - bound_layer_index = layer_index if conductor is not None: conductor.add_layer(orig_module, included_offload_param_indices) if conductor is not None and conductor.offload_activated(): - def offloaded_custom_forward( - # dummy tensor that requires grad is needed for checkpointing to work when training a LoRA - dummy: torch.Tensor, - call_id: int, - *args, - ): - if bound_layer_index == 0 and not torch.is_grad_enabled(): - bound_conductor.start_forward(True) - - args = bound_conductor.before_layer(bound_layer_index, call_id, args) - output = orig_forward(*args) - bound_conductor.after_layer(bound_layer_index, call_id, args) - - # make sure at least one of the output 
tensors has a grad_fn so the output of the checkpoint has a grad_fn - if torch.is_grad_enabled() and not has_grad_fn(output): - output = add_dummy_grad_fn_(output) - - return output - - def custom_forward( - call_index: int, - *args, - ): - if bound_layer_index == 0: - bound_conductor.start_forward(False) - - args = bound_conductor.before_layer(bound_layer_index, call_index, args) - output = orig_forward(*args) - bound_conductor.after_layer(bound_layer_index, call_index, args) - return output - - def forward( - *args, - **kwargs - ): - call_id = __generate_call_index() - - if torch.is_grad_enabled(): - dummy = torch.zeros((1,), device=train_device) - dummy.requires_grad_(True) - - args = __kwargs_to_args(orig_forward, args, kwargs) - - return checkpoint( - offloaded_custom_forward, - dummy, - call_id, - *args, - use_reentrant=True - ) - else: - args = __kwargs_to_args(orig_forward, args, kwargs) - return custom_forward(call_id, *args) + layer = OffloadCheckpointLayer(orig_module, train_device, conductor, layer_index) + if compile: + #don't compile the checkpointing layer - offloading cannot be compiled: + orig_module.compile(fullgraph=True) else: - def custom_forward( - # dummy tensor that requires grad is needed for checkpointing to work when training a LoRA - dummy: torch.Tensor = None, - *args, - **kwargs, - ): - return orig_forward( - *args, - **kwargs, + layer = CheckpointLayer(orig_module, train_device) if enabled else orig_module + if compile: + #do compile the checkpointing layer - slightly faster + layer.compile(fullgraph=True) + return layer + +def _create_checkpoints_for_module_list( + module_list: nn.ModuleList, + include_from_offload_param_names: list[str], + conductor: LayerOffloadConductor, + train_device: torch.device, + layer_index: int, + compile: bool, +) -> int: + + for i, layer in enumerate(module_list): + module_list[i] = create_checkpoint( + layer, train_device, + include_from_offload_param_names, + conductor, layer_index, compile=compile, ) + layer_index += 1 + return layer_index - def forward( - *args, - **kwargs - ): - if torch.is_grad_enabled(): - dummy = torch.zeros((1,), device=train_device) - dummy.requires_grad_(True) - return checkpoint( - custom_forward, - dummy, - *args, - **kwargs, - use_reentrant=False - ) - else: - return custom_forward(None, *args, **kwargs) - - return forward - - -def enable_checkpointing_for_basic_transformer_blocks( - orig_module: nn.Module, +def enable_checkpointing( + model: nn.Module, config: TrainConfig, - offload_enabled: bool, + compile: bool, + lists, + offload_enabled: bool = True, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) + conductor = LayerOffloadConductor(model, config) layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, BasicTransformerBlock): - if offload_enabled: - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], - conductor, layer_index, - ) - else: - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], - ) - layer_index += 1 + for module_list, param_names in lists: + layer_index = _create_checkpoints_for_module_list( + module_list, + param_names, + conductor if offload_enabled else None, + torch.device(config.train_device), + layer_index, + compile = compile, + ) return conductor +#TODO test all models +def enable_checkpointing_for_basic_transformer_blocks( + model: nn.Module, + config: TrainConfig, + offload_enabled: 
bool, +) -> LayerOffloadConductor: + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, []), + ], + offload_enabled = offload_enabled, + ) + def enable_checkpointing_for_clip_encoder_layers( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ): - for child_module in orig_module.modules(): - if isinstance(child_module, CLIPEncoderLayer): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], - ) - + return enable_checkpointing(model, config, False, [ + (model.text_model.encoder.layers, []), # No activation offloading for text encoders, because the output might be taken from the middle of the network + ]) def enable_checkpointing_for_stable_cascade_blocks( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, SDCascadeResBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], - conductor, layer_index, - ) - layer_index += 1 - if isinstance(child_module, SDCascadeAttnBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], - conductor, layer_index, - ) - layer_index += 1 - if isinstance(child_module, SDCascadeTimestepBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], - conductor, layer_index, - ) - layer_index += 1 - - return conductor - + return enable_checkpointing(model, config, config.compile, [ + (model.down_blocks, []), + (model.up_blocks, []), + ]) def enable_checkpointing_for_t5_encoder_layers( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, T5Block): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], # No activation offloading, because the output might be taken from the middle of the network - conductor, layer_index, - ) - layer_index += 1 - - return conductor + return enable_checkpointing(model, config, False, [ + (model.encoder.block, []), + ]) def enable_checkpointing_for_gemma_layers( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, Gemma2DecoderLayer): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], # No activation offloading, because the output might be taken from the middle of the network - conductor, layer_index, - ) - layer_index += 1 - - return conductor + return enable_checkpointing(model, config, False, [ + (model.layers, []), + ]) def enable_checkpointing_for_llama_encoder_layers( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, LlamaDecoderLayer): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], # No activation 
offloading, because the output might be taken from the middle of the network - conductor, layer_index, - ) - layer_index += 1 - - return conductor + return enable_checkpointing(model, config, False, [ + (model.model.layers, []), + ]) def enable_checkpointing_for_qwen_encoder_layers( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, Qwen2_5_VLDecoderLayer): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - [], # TODO No activation offloading for other encoders, see above. But clip skip is not implemented for QwenVL. Then do activation offloading? - conductor, layer_index, - ) - layer_index += 1 - - return conductor - + return enable_checkpointing(model, config, False, [ + (model.model.language_model.layers, []), # TODO No activation offloading for other encoders, see above. But clip skip is not implemented for QwenVL. Then do activation offloading? + ]) def enable_checkpointing_for_stable_diffusion_3_transformer( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, JointTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states", "encoder_hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - return conductor - + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + ]) def enable_checkpointing_for_flux_transformer( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, FluxTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states", "encoder_hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - for child_module in orig_module.modules(): - if isinstance(child_module, FluxSingleTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states"], - conductor, layer_index, - ) - layer_index += 1 + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + (model.single_transformer_blocks, ["hidden_states" ]), + ]) - return conductor def enable_checkpointing_for_chroma_transformer( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, ChromaTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states", "encoder_hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - for child_module in orig_module.modules(): - if isinstance(child_module, ChromaSingleTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, 
torch.device(config.train_device), - ["hidden_states"], - conductor, layer_index, - ) - layer_index += 1 + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + (model.single_transformer_blocks, ["hidden_states" ]), + ]) - return conductor def enable_checkpointing_for_qwen_transformer( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, QwenImageTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states", "encoder_hidden_states"], - conductor, layer_index, - ) - layer_index += 1 + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + ]) - return conductor def enable_checkpointing_for_sana_transformer( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, SanaTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - return conductor + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states"]), + ]) def enable_checkpointing_for_hunyuan_video_transformer( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, HunyuanVideoIndividualTokenRefinerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - for child_module in orig_module.modules(): - if isinstance(child_module, HunyuanVideoTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states", "encoder_hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - for child_module in orig_module.modules(): - if isinstance(child_module, HunyuanVideoSingleTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - return conductor + return enable_checkpointing(model, config, config.compile, [ + (model.context_embedder.token_refiner.refiner_blocks, ["hidden_states" ]), + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + (model.single_transformer_blocks, ["hidden_states" ]), + ]) def enable_checkpointing_for_hi_dream_transformer( - orig_module: nn.Module, + model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: - conductor = LayerOffloadConductor(orig_module, config) - - layer_index = 0 - for child_module in orig_module.modules(): - if isinstance(child_module, HiDreamImageTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states", "encoder_hidden_states"], - conductor, 
layer_index, - ) - layer_index += 1 - - for child_module in orig_module.modules(): - if isinstance(child_module, HiDreamImageSingleTransformerBlock): - child_module.forward = create_checkpointed_forward( - child_module, torch.device(config.train_device), - ["hidden_states"], - conductor, layer_index, - ) - layer_index += 1 - - return conductor + return enable_checkpointing(model, config, config.compile, [ + (model.double_stream_blocks, ["hidden_states", "encoder_hidden_states"]), + (model.single_stream_blocks, ["hidden_states" ]), + ]) diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index 220fcd02a..9b6419fa0 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -328,6 +328,9 @@ class TrainConfig(BaseConfig): enable_activation_offloading: bool layer_offload_fraction: float force_circular_padding: bool + compile: bool + svd_dtype: DataType + svd_rank: int # data settings concept_file_name: str @@ -867,6 +870,9 @@ def default_values() -> 'TrainConfig': data.append(("enable_activation_offloading", True, bool, False)) data.append(("layer_offload_fraction", 0.0, float, False)) data.append(("force_circular_padding", False, bool, False)) + data.append(("compile", True, bool, False)) + data.append(("svd_dtype", DataType.FLOAT_32, DataType, False)) + data.append(("svd_rank", 16, int, False)) # data settings data.append(("concept_file_name", "training_concepts/concepts.json", str, False)) diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py index 45e547507..4289adfd0 100644 --- a/modules/util/enum/DataType.py +++ b/modules/util/enum/DataType.py @@ -12,6 +12,12 @@ class DataType(Enum): TFLOAT_32 = 'TFLOAT_32' INT_8 = 'INT_8' NFLOAT_4 = 'NFLOAT_4' + FLOAT_W8A8 = 'FLOAT_W8A8' + INT_W8A8 = 'INT_W8A8' + FLOAT_8_SVD = 'FLOAT_8_SVD' + NFLOAT_4_SVD = 'NFLOAT_4_SVD' + FLOAT_W8A8_SVD = 'FLOAT_W8A8_SVD' + INT_W8A8_SVD = 'INT_W8A8_SVD' def __str__(self): return self.value @@ -41,13 +47,31 @@ def enable_tf(self): def is_quantized(self): return self in [DataType.FLOAT_8, DataType.INT_8, - DataType.NFLOAT_4] + DataType.FLOAT_W8A8, + DataType.INT_W8A8, + DataType.NFLOAT_4, + DataType.FLOAT_8_SVD, + DataType.FLOAT_W8A8_SVD, + DataType.INT_W8A8_SVD, + DataType.NFLOAT_4_SVD] def quantize_fp8(self): - return self == DataType.FLOAT_8 + return self == DataType.FLOAT_8 or self == DataType.FLOAT_8_SVD def quantize_int8(self): return self == DataType.INT_8 + def quantize_fpW8A8(self): + return self == DataType.FLOAT_W8A8 or self == DataType.FLOAT_W8A8_SVD + + def quantize_intW8A8(self): + return self == DataType.INT_W8A8 or self == DataType.INT_W8A8_SVD + def quantize_nf4(self): - return self == DataType.NFLOAT_4 + return self == DataType.NFLOAT_4 or self == DataType.NFLOAT_4_SVD + + def quantize_svd(self): + return self in [DataType.FLOAT_8_SVD, + DataType.NFLOAT_4_SVD, + DataType.FLOAT_W8A8_SVD, + DataType.INT_W8A8_SVD] diff --git a/modules/util/memory_util.py b/modules/util/memory_util.py deleted file mode 100644 index ef93ee0f1..000000000 --- a/modules/util/memory_util.py +++ /dev/null @@ -1,24 +0,0 @@ -import platform - -import torch - - -class TorchMemoryRecorder: - - def __init__(self, filename: str = "memory.pickle", enabled: bool = True): - self.filename = filename - self.enabled = enabled and platform.system() == 'Linux' - - def __enter__(self): - if self.enabled: - torch.cuda.memory._record_memory_history() - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.enabled: - try: - 
torch.cuda.memory._dump_snapshot(filename=self.filename) - print(f"dumped memory snapshot to {self.filename}") - except Exception: - print(f"could not dump memory snapshot {self.filename}") - - torch.cuda.memory._record_memory_history(enabled=None) diff --git a/modules/util/profiling_util.py b/modules/util/profiling_util.py new file mode 100644 index 000000000..d09a19dab --- /dev/null +++ b/modules/util/profiling_util.py @@ -0,0 +1,53 @@ +import platform + +import torch + + +class TorchMemoryRecorder: + def __init__(self, filename: str = "memory.pickle", enabled: bool = True): + self.filename = filename + self.enabled = enabled and platform.system() == 'Linux' + + def __enter__(self): + if self.enabled: + torch.cuda.memory._record_memory_history() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enabled: + try: + torch.cuda.memory._dump_snapshot(filename=self.filename) + print(f"dumped memory snapshot to {self.filename}") + except Exception: + print(f"could not dump memory snapshot {self.filename}") + + torch.cuda.memory._record_memory_history(enabled=None) + +class TorchProfiler: + def __init__(self, filename: str, enabled: bool = True): + self.filename = filename + self.enabled = enabled + self.profiler = None + + def __enter__(self): + if self.enabled: + profiler_context = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + ) + self.profiler = profiler_context.__enter__() + return self.profiler + else: + return None + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.profiler is not None: + ret = self.profiler.__exit__(exc_type, exc_val, exc_tb) + try: + self.profiler.export_chrome_trace(self.filename) + except Exception: + print(f"could not write profiler output {self.filename}") + return ret + else: + return False diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index d57789a22..168313ea9 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -1,13 +1,21 @@ +import os from collections.abc import Callable +from contextlib import suppress +from functools import partial from modules.module.quantized.LinearFp8 import LinearFp8 +from modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear +from modules.module.quantized.LinearW8A8 import LinearW8A8 from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin +from modules.util.config.TrainConfig import TrainConfig from modules.util.enum.DataType import DataType import torch from torch import Tensor, nn +from tqdm import tqdm + try: from modules.module.quantized.LinearNf4 import LinearNf4 @@ -16,46 +24,10 @@ bnb = None LinearNf4 = None - -def __create_nf4_linear_layer(module: nn.Linear, copy_parameters: bool) -> nn.Module: - bias = module.bias is not None - - quant_linear = LinearNf4( - in_features=module.in_features, - out_features=module.out_features, - bias=bias, - ) - - if copy_parameters: - quant_linear.weight.data = module.weight.data - if bias: - quant_linear.bias.data = module.bias.data - - return quant_linear - - -def __create_int8_linear_layer(module: nn.Linear, copy_parameters: bool) -> nn.Module: - bias = module.bias is not None - - quant_linear = bnb.nn.Linear8bitLt( - input_features=module.in_features, - output_features=module.out_features, - bias=bias, - has_fp16_weights=False, - ) - - if copy_parameters: - quant_linear.weight = 
type(quant_linear.weight)(module.weight) - if bias: - quant_linear.bias = type(quant_linear.bias)(module.bias) - - return quant_linear - - -def __create_fp8_linear_layer(module: nn.Linear, copy_parameters: bool) -> nn.Module: +def __create_linear_layer(construct_fn, module: nn.Linear, copy_parameters: bool) -> nn.Module: bias = module.bias is not None - quant_linear = LinearFp8( + quant_linear = construct_fn( in_features=module.in_features, out_features=module.out_features, bias=bias, @@ -71,7 +43,7 @@ def __create_fp8_linear_layer(module: nn.Linear, copy_parameters: bool) -> nn.Mo def __replace_linear_layers( parent_module: nn.Module, - convert_fn: Callable[[nn.Linear, bool], nn.Module], + construct_fn, keep_in_fp32_modules: list[str] | None = None, copy_parameters: bool = False, name_prefix: str = "", @@ -89,13 +61,13 @@ def __replace_linear_layers( if isinstance(parent_module, nn.ModuleList): for i, module in enumerate(parent_module): if isinstance(module, nn.Linear): - quant_linear = convert_fn(module, copy_parameters) + quant_linear = __create_linear_layer(construct_fn, module, copy_parameters) parent_module[i] = quant_linear del module elif id(module) not in visited_modules: __replace_linear_layers( parent_module=module, - convert_fn=convert_fn, + construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, copy_parameters=copy_parameters, name_prefix=f"{name_prefix}[{i}]", @@ -108,13 +80,13 @@ def __replace_linear_layers( module = getattr(parent_module, attr_name) if isinstance(module, nn.Linear): - quant_linear = convert_fn(module, copy_parameters) + quant_linear = __create_linear_layer(construct_fn, module, copy_parameters) setattr(parent_module, attr_name, quant_linear) del module elif isinstance(module, nn.Module) and id(module) not in visited_modules: __replace_linear_layers( parent_module=module, - convert_fn=convert_fn, + construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, copy_parameters=copy_parameters, name_prefix=f"{name_prefix}.{attr_name}", @@ -122,40 +94,28 @@ def __replace_linear_layers( ) -def replace_linear_with_nf4_layers( +def replace_linear_with_quantized_layers( parent_module: nn.Module, + dtype: DataType, keep_in_fp32_modules: list[str] | None = None, copy_parameters: bool = False, ): - __replace_linear_layers( - parent_module=parent_module, - convert_fn=__create_nf4_linear_layer, - keep_in_fp32_modules=keep_in_fp32_modules, - copy_parameters=copy_parameters, - ) - - -def replace_linear_with_int8_layers( - parent_module: nn.Module, - keep_in_fp32_modules: list[str] | None = None, - copy_parameters: bool = False, -): - __replace_linear_layers( - parent_module=parent_module, - convert_fn=__create_int8_linear_layer, - keep_in_fp32_modules=keep_in_fp32_modules, - copy_parameters=copy_parameters, - ) - + if dtype.quantize_nf4(): + construct_fn = make_svd_linear(LinearNf4) if dtype.quantize_svd() else LinearNf4 + elif dtype.quantize_int8(): + construct_fn = partial(make_svd_linear(bnb.nn.Linear8bitLt) if dtype.quantize_svd() else bnb.nn.Linear8bitLt, has_fp16_weights=False) + elif dtype.quantize_fp8(): + construct_fn = make_svd_linear(LinearFp8) if dtype.quantize_svd() else LinearFp8 + elif dtype.quantize_intW8A8(): + construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.int8, compute_dtype=torch.bfloat16) + elif dtype.quantize_fpW8A8(): + construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + 
else: + return -def replace_linear_with_fp8_layers( - parent_module: nn.Module, - keep_in_fp32_modules: list[str] | None = None, - copy_parameters: bool = False, -): __replace_linear_layers( parent_module=parent_module, - convert_fn=__create_fp8_linear_layer, + construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, copy_parameters=copy_parameters, ) @@ -165,6 +125,9 @@ def is_quantized_parameter( module: nn.Module, parameter_name: str, ) -> bool: + if isinstance(module, BaseLinearSVD): + if parameter_name in ["svd_up", "svd_down"]: + return True if bnb is not None: if isinstance(module, LinearNf4): return parameter_name in [ @@ -178,18 +141,22 @@ def is_quantized_parameter( elif isinstance(module, bnb.nn.Linear8bitLt): return parameter_name == "weight" - if isinstance(module, LinearFp8): + if isinstance(module, (LinearFp8, LinearW8A8)): return parameter_name == "weight" return False -def quantize_layers(module: nn.Module, device: torch.device, train_dtype: DataType): +def quantize_layers(module: nn.Module, device: torch.device, train_dtype: DataType, config: TrainConfig): if module is not None: - for child_module in module.modules(): + cache_dir = config.cache_dir + "/quantization" + with suppress(FileExistsError): + os.mkdir(cache_dir) + child_modules = list(module.modules()) + for child_module in tqdm(child_modules, desc="Quantizing model weights", total=len(child_modules), delay=5, smoothing=0.1): if isinstance(child_module, QuantizedModuleMixin): child_module.compute_dtype = train_dtype.torch_dtype() - child_module.quantize(device) + child_module.quantize(device=device, cache_dir=cache_dir, svd_dtype=config.svd_dtype.torch_dtype(), rank=config.svd_rank) def get_unquantized_weight(module: nn.Module, dtype: torch.dtype, device: torch.device) -> Tensor: @@ -219,6 +186,8 @@ def get_offload_tensors(module: nn.Module) -> list[torch.Tensor]: tensors += [module.weight] if isinstance(module, nn.Linear) and module.bias is not None: tensors += [module.bias] + if isinstance(module, BaseLinearSVD): + tensors += [module.svd_up, module.svd_down] return tensors diff --git a/modules/util/triton_mm_8bit.py b/modules/util/triton_mm_8bit.py new file mode 100644 index 000000000..8250e1270 --- /dev/null +++ b/modules/util/triton_mm_8bit.py @@ -0,0 +1,121 @@ +#This is a 8bit matmul kernel adapted from the Triton tutorial here: +#https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + +#It is not optimized and about 10% slower than torch._int_mm and torch._scaled_mm +#However, the torch functions don't work on row-major rhs matrices: +#_scaled_mm fails, _int_mm automatically converts to column-major +# +#Converting to column-major is slow, which is significant because the weights matrix +#of a Linear layer is always column-major during the backward pass. 
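A rough standalone sketch of the backward-pass call this kernel is written for (illustrative shapes and values, not part of the patch; assumes a CUDA device with Triton installed):

import torch
from modules.util.triton_mm_8bit import mm_8bit  # the wrapper defined at the end of this file

# backward-pass shapes: grad_output (M, out_features) times weight (out_features, in_features)
grad_out_q = torch.randint(-128, 127, (1024, 4096), dtype=torch.int8, device="cuda")
weight_q = torch.randint(-128, 127, (4096, 2048), dtype=torch.int8, device="cuda")

grad_in = mm_8bit(grad_out_q, weight_q)    # int32 result, reads weight_q through its natural strides
ref = torch._int_mm(grad_out_q, weight_q)  # per the note above, torch first converts the layout
assert torch.equal(grad_in, ref)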
+# +#In these cases, this Triton kernel is *much* faster because it can access the +#row-major weight matrix directly, using strided memory access + +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128}, num_stages=3,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=3,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5,num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5,num_warps=2), + + + ], + key=[ + 'QUANTIZED_M', #only tune roughly on M, because M is the transformer sequence length - can vary on data + 'N', + 'K', + 'stride_bk' #use stride of b as key, to autotune again for a strided rhs matrix (backward pass) + ], +) + +@triton.jit +def __mm_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + QUANTIZED_M, + FLOAT: tl.constexpr, +): + + pid_n = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 if FLOAT else tl.int32) + + for k in range(tl.cdiv(K, BLOCK_SIZE_K)): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k*BLOCK_SIZE_K) + b_mask = (offs_bn[None, :] < N) & (offs_k[:, None] < K - k*BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32 if FLOAT else tl.int32) + + 
a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +def mm_8bit(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert a.dtype == b.dtype, "Incompatible dtypes" + assert a.dtype in [torch.int8, torch.float8_e4m3fn] + + FLOAT = (a.dtype == torch.float8_e4m3fn) + + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float32 if FLOAT else torch.int32) + + def grid(META): + return (triton.cdiv(N, META['BLOCK_SIZE_N']) , triton.cdiv(M, META['BLOCK_SIZE_M']), ) + __mm_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + QUANTIZED_M = M // 64, + FLOAT = FLOAT + ) + return c diff --git a/requirements-cuda.txt b/requirements-cuda.txt index f54914afc..c60082db5 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -1,9 +1,9 @@ # pytorch --extra-index-url https://download.pytorch.org/whl/cu128 -torch==2.7.1+cu128 -torchvision==0.22.1+cu128 +torch==2.8.0+cu128 +torchvision==0.23.0+cu128 onnxruntime-gpu==1.22.0 -nvidia-nccl-cu12==2.26.2; sys_platform == "linux" #TODO upgrade to support RTX 5090 multi-GPU +nvidia-nccl-cu12==2.27.3; sys_platform == "linux" # optimizers bitsandbytes==0.46.0 # bitsandbytes for 8-bit optimizers and weight quantization diff --git a/requirements-global.txt b/requirements-global.txt index 521df52e8..2111386eb 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -35,7 +35,7 @@ pooch==1.8.2 open-clip-torch==2.32.0 # data loader --e git+https://github.com/Nerogar/mgds.git@50a2394#egg=mgds +-e git+https://github.com/dxqb/mgds.git@gpu#egg=mgds # optimizers dadaptation==3.2 # dadaptation optimizers diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 33c92158f..936872a4b 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -1,3 +1,6 @@ +# Note: AMD requirements might be outdated. 
If you can provide information about running OneTrainer on AMD, +# please open an issue or pull request on github + # pytorch --extra-index-url https://download.pytorch.org/whl/rocm6.3 torch==2.7.1+rocm6.3 From bbf6eb8e80e2d4fba83946b599c9fdc41f87c3cc Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 5 Oct 2025 12:03:33 +0200 Subject: [PATCH 02/54] - fix cache dir - increase offloading alignment to 16 - disable grads for SVDQuant to save vram --- modules/module/quantized/LinearSVD.py | 2 ++ modules/util/LayerOffloadConductor.py | 14 +++++++------- modules/util/quantization_util.py | 4 +--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/module/quantized/LinearSVD.py b/modules/module/quantized/LinearSVD.py index 5e612e8b5..200d01b5e 100644 --- a/modules/module/quantized/LinearSVD.py +++ b/modules/module/quantized/LinearSVD.py @@ -69,6 +69,8 @@ def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | Non self.svd_down = Vh_r.to(svd_dtype) self.svd_up = (U_r * S_r.unsqueeze(0)).to(svd_dtype) + self.svd_down.requires_grad_(False) + self.svd_up.requires_grad_(False) self.weight.data = (W - (self.svd_up @ self.svd_down)).to(dtype=self.weight.dtype, device=orig_device) super().quantize(device) diff --git a/modules/util/LayerOffloadConductor.py b/modules/util/LayerOffloadConductor.py index 43163cdce..79229271d 100644 --- a/modules/util/LayerOffloadConductor.py +++ b/modules/util/LayerOffloadConductor.py @@ -34,12 +34,12 @@ def clone_tensor_allocator(tensor: torch.Tensor) -> torch.Tensor: return tensor.clone() -def ceil_8(number: int) -> int: - return number + (8 - (number % 8)) % 8 +def ceil_16(number: int) -> int: + return number + (16 - (number % 16)) % 16 -def floor_8(number: int) -> int: - return number - (number % 8) +def floor_16(number: int) -> int: + return number - (number % 16) class StaticLayerTensorAllocator: @@ -69,7 +69,7 @@ def allocate_like(self, source_tensor: torch.Tensor) -> torch.Tensor: total_cache_bytes = cache_tensor_size * len(self.__layer_allocator.cache_tensors) if self.__allocate_forward: cache_tensor_index = self.__allocation_end // cache_tensor_size - cache_tensor_allocation_end = ceil_8(self.__allocation_end % cache_tensor_size) + cache_tensor_allocation_end = ceil_16(self.__allocation_end % cache_tensor_size) if cache_tensor_allocation_end + num_bytes > cache_tensor_size: # move to the start of the next cache tensor @@ -100,7 +100,7 @@ def allocate_like(self, source_tensor: torch.Tensor) -> torch.Tensor: cache_tensor_index = len(self.__layer_allocator.cache_tensors) - 1 cache_tensor_allocation_start = cache_tensor_size - new_allocation_start = floor_8(cache_tensor_allocation_start - num_bytes) + new_allocation_start = floor_16(cache_tensor_allocation_start - num_bytes) self.__layer_allocator.ensure_allocation(cache_tensor_index) cache_tensor = self.__layer_allocator.cache_tensors[cache_tensor_index] allocated_tensor = cache_tensor[new_allocation_start:new_allocation_start + num_bytes] @@ -284,7 +284,7 @@ def allocate_like(self, source_tensor: torch.Tensor) -> torch.Tensor: cache_tensor = self.__cache_tensors[self.__current_cache_tensor] allocated_tensor = \ cache_tensor[self.__current_cache_tensor_offset:self.__current_cache_tensor_offset + num_bytes] - self.__current_cache_tensor_offset += ceil_8(num_bytes) + self.__current_cache_tensor_offset += ceil_16(num_bytes) return allocated_tensor.view(dtype=source_tensor.dtype).view(size=source_tensor.shape) diff --git a/modules/util/quantization_util.py 
b/modules/util/quantization_util.py index 168313ea9..3c827ad92 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -1,6 +1,5 @@ import os from collections.abc import Callable -from contextlib import suppress from functools import partial from modules.module.quantized.LinearFp8 import LinearFp8 @@ -150,8 +149,7 @@ def is_quantized_parameter( def quantize_layers(module: nn.Module, device: torch.device, train_dtype: DataType, config: TrainConfig): if module is not None: cache_dir = config.cache_dir + "/quantization" - with suppress(FileExistsError): - os.mkdir(cache_dir) + os.makedirs(cache_dir, exist_ok=True) child_modules = list(module.modules()) for child_module in tqdm(child_modules, desc="Quantizing model weights", total=len(child_modules), delay=5, smoothing=0.1): if isinstance(child_module, QuantizedModuleMixin): From 2676d330dffe1431d115b9874c3801de0803f174 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 5 Oct 2025 16:23:01 +0200 Subject: [PATCH 03/54] hide checkpoints from LoRA saving --- modules/modelSaver/mixin/LoRASaverMixin.py | 6 ++++-- modules/module/quantized/LinearSVD.py | 2 +- modules/util/checkpointing_util.py | 17 +++++++---------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/modules/modelSaver/mixin/LoRASaverMixin.py b/modules/modelSaver/mixin/LoRASaverMixin.py index 904997a4f..fc3f37440 100644 --- a/modules/modelSaver/mixin/LoRASaverMixin.py +++ b/modules/modelSaver/mixin/LoRASaverMixin.py @@ -41,7 +41,8 @@ def __save_safetensors( destination: str, dtype: torch.dtype | None, ): - state_dict = self._get_state_dict(model) + state_dict_checkpointed = self._get_state_dict(model) + state_dict = {k.replace(".checkpoint.", "."): v for k, v in state_dict_checkpointed.items()} save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) key_sets = self._get_convert_key_sets(model) @@ -57,7 +58,8 @@ def __save_legacy_safetensors( destination: str, dtype: torch.dtype | None, ): - state_dict = self._get_state_dict(model) + state_dict_checkpointed = self._get_state_dict(model) + state_dict = {k.replace(".checkpoint.", "."): v for k, v in state_dict_checkpointed.items()} save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) key_sets = self._get_convert_key_sets(model) diff --git a/modules/module/quantized/LinearSVD.py b/modules/module/quantized/LinearSVD.py index 200d01b5e..76966f6be 100644 --- a/modules/module/quantized/LinearSVD.py +++ b/modules/module/quantized/LinearSVD.py @@ -67,7 +67,7 @@ def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | Non S_r = S[:rank] Vh_r = Vh[:rank, :] - self.svd_down = Vh_r.to(svd_dtype) + self.svd_down = Vh_r.clone().to(svd_dtype) self.svd_up = (U_r * S_r.unsqueeze(0)).to(svd_dtype) self.svd_down.requires_grad_(False) self.svd_up.requires_grad_(False) diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index ac5e17a40..0a25717b5 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -50,13 +50,12 @@ def _generate_call_index() -> int: class CheckpointLayer(torch.nn.Module): def __init__(self, orig: nn.Module, train_device: torch.device): super().__init__() - self.orig = orig + self.checkpoint = orig # dummy tensor that requires grad is needed for checkpointing to work when training a LoRA self.dummy = torch.zeros((1,), device=train_device, requires_grad=True) - #self.orig.compile(fullgraph=True) def __checkpointing_forward(self, dummy: torch.Tensor, *args, **kwargs): - return 
self.orig(*args, **kwargs) + return self.checkpoint(*args, **kwargs) def forward(self, *args, **kwargs): if torch.is_grad_enabled(): @@ -68,13 +67,12 @@ def forward(self, *args, **kwargs): use_reentrant=False ) else: - return self.orig(*args, **kwargs) - + return self.checkpoint(*args, **kwargs) class OffloadCheckpointLayer(torch.nn.Module): def __init__(self, orig: nn.Module, train_device: torch.device, conductor: LayerOffloadConductor, layer_index: int): super().__init__() - self.orig = orig + self.checkpoint = orig self.dummy = torch.zeros((1,), device=train_device, requires_grad=True) self.conductor = conductor self.layer_index = layer_index @@ -85,7 +83,7 @@ def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): self.conductor.start_forward(True) args = self.conductor.before_layer(self.layer_index, call_id, args) - output = self.orig(*args) + output = self.checkpoint(*args) self.conductor.after_layer(self.layer_index, call_id, args) # make sure at least one of the output tensors has a grad_fn so the output of the checkpoint has a grad_fn @@ -99,7 +97,7 @@ def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): def forward(self, *args, **kwargs): call_id = _generate_call_index() - args = _kwargs_to_args(self.orig.forward, args, kwargs) + args = _kwargs_to_args(self.checkpoint.forward, args, kwargs) if torch.is_grad_enabled(): return checkpoint( @@ -114,11 +112,10 @@ def forward(self, *args, **kwargs): self.conductor.start_forward(False) args = self.conductor.before_layer(self.layer_index, call_id, args) - output = self.orig(*args) + output = self.checkpoint(*args) self.conductor.after_layer(self.layer_index, call_id, args) return output - def create_checkpoint( orig_module: nn.Module, train_device: torch.device, From dc289fa27901f00e8ab613da705be51645faa57e Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 6 Oct 2025 09:58:50 +0200 Subject: [PATCH 04/54] fix buffer registration --- modules/module/quantized/LinearSVD.py | 14 ++++++--- modules/module/quantized/LinearW8A8.py | 43 +++++++++++++------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/modules/module/quantized/LinearSVD.py b/modules/module/quantized/LinearSVD.py index 76966f6be..fc65ebf39 100644 --- a/modules/module/quantized/LinearSVD.py +++ b/modules/module/quantized/LinearSVD.py @@ -21,6 +21,9 @@ def _get_tensor_hash(t: torch.Tensor) -> str: hash_obj = hashlib.sha256(tensor_bytes) return hash_obj.hexdigest() + +log_obj = None + def make_svd_linear(linear_class): class LinearSVD( linear_class, @@ -67,15 +70,16 @@ def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | Non S_r = S[:rank] Vh_r = Vh[:rank, :] - self.svd_down = Vh_r.clone().to(svd_dtype) - self.svd_up = (U_r * S_r.unsqueeze(0)).to(svd_dtype) - self.svd_down.requires_grad_(False) - self.svd_up.requires_grad_(False) + svd_down = Vh_r.clone().to(svd_dtype) + svd_up = (U_r * S_r.unsqueeze(0)).to(svd_dtype) + self.register_buffer("svd_up", svd_up) + self.register_buffer("svd_down", svd_down) - self.weight.data = (W - (self.svd_up @ self.svd_down)).to(dtype=self.weight.dtype, device=orig_device) + self.weight.data = (W - (svd_up @ svd_down)).to(dtype=self.weight.dtype, device=orig_device) super().quantize(device) def forward(self, x: torch.Tensor) -> torch.Tensor: + assert not self.svd_down.requires_grad and not self.svd_up.requires_grad return ((x @ self.svd_down.T) @ self.svd_up.T).to(x.dtype) + super().forward(x) return LinearSVD diff --git a/modules/module/quantized/LinearW8A8.py 
b/modules/module/quantized/LinearW8A8.py index 623e391e8..22ad9928e 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -107,56 +107,55 @@ class LinearW8A8( def __init__(self, dtype, compute_dtype, *args, **kwargs): super().__init__(*args, **kwargs) - self.is_quantized = False assert dtype in [torch.int8, torch.float8_e4m3fn] self._dtype = dtype self._compute_dtype = compute_dtype - self._scale = torch.tensor(1.0, dtype=torch.float32) - self.register_buffer("scale", self._scale) + self.register_buffer("scale", None) def original_weight_shape(self) -> tuple[int, ...]: return self.weight.shape def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: - if self._scale is not None: - return unquantize(self.weight.detach(), self._scale, self._compute_dtype).to(dtype) + if self.scale is not None: + return unquantize(self.weight.detach(), self.scale, self._compute_dtype).to(dtype) else: return self.weight.detach().to(dtype) def quantize(self, device: torch.device | None = None, **kwargs): - if self.is_quantized: + if self.scale is not None: return - self.is_quantized = True - self.weight.requires_grad_(False) - weight = self.weight.data + weight = self.weight.detach() orig_device = weight.device - if weight.dtype != self._dtype: - if device is not None: - weight = weight.to(device=device) + if device is not None: + weight = weight.to(device=device) - if self._dtype == torch.int8: - weight, self._scale = quantize_int8_tensorwise(weight) - else: - weight, self._scale = quantize_fp8_tensorwise(weight) + if self._dtype == torch.int8: + weight, scale = quantize_int8_tensorwise(weight) + else: + weight, scale = quantize_fp8_tensorwise(weight) + + if device is not None: + weight = weight.to(device=orig_device) + scale = scale.to(device=orig_device) - if device is not None: - weight = weight.to(device=orig_device) - self.weight.data = weight + self.weight = nn.Parameter(weight, requires_grad=False) + self.register_buffer("scale", scale) def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + assert not self.weight.requires_grad x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) if x.shape[0] > 16: if self._dtype == torch.int8: - y = LinearInt8Function.apply(x, self.weight, self._scale, self.bias) + y = LinearInt8Function.apply(x, self.weight, self.scale, self.bias) else: - y = LinearFp8Function.apply(x, self.weight, self._scale, self.bias) + y = LinearFp8Function.apply(x, self.weight, self.scale, self.bias) else: - w = unquantize(self.weight, self._scale, compute_dtype=self._compute_dtype) + w = unquantize(self.weight, self.scale, compute_dtype=self._compute_dtype) y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) assert y.dtype == self._compute_dtype From 73822b0df7eff57459a9cd480c5ec0c9dd27539a Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 6 Oct 2025 10:01:23 +0200 Subject: [PATCH 05/54] fix buffer registration --- modules/module/quantized/LinearSVD.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/module/quantized/LinearSVD.py b/modules/module/quantized/LinearSVD.py index fc65ebf39..7f3328a2f 100644 --- a/modules/module/quantized/LinearSVD.py +++ b/modules/module/quantized/LinearSVD.py @@ -21,9 +21,6 @@ def _get_tensor_hash(t: torch.Tensor) -> str: hash_obj = hashlib.sha256(tensor_bytes) return hash_obj.hexdigest() - -log_obj = None - def make_svd_linear(linear_class): class LinearSVD( linear_class, From 4821c9ef387a36c234d6000de4d870d73eafa460 Mon Sep 17 00:00:00 2001 
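For orientation, a minimal self-contained sketch of the W8A8 scheme LinearW8A8 implements: one tensorwise int8 scale for the weight, one scale per activation row, integer accumulation, and both scales applied to the result afterwards (arbitrary shapes and values, not code from the patch; the GPU path uses torch._int_mm and the Triton kernel instead of the plain matmul shown here):

import torch

w = torch.randn(2048, 4096)  # Linear weight (out_features, in_features)
x = torch.randn(32, 4096)    # a batch of activation rows

# tensorwise weight scale: a single float for the whole tensor
w_scale = (w.abs().max().float() / 127.0).clamp(min=1e-12)
w_q = w.float().mul(1.0 / w_scale).round_().clamp_(-128, 127).to(torch.int8)

# per-row (tokenwise) activation scale
x_scale = (x.abs().amax(dim=-1, keepdim=True).float() / 127.0).clamp(min=1e-12)
x_q = x.float().mul(1.0 / x_scale).round_().clamp_(-128, 127).to(torch.int8)

# integer matmul (int64 here as a CPU-runnable stand-in; the GPU kernels accumulate in int32),
# then both scales are applied to the accumulator
acc = x_q.long() @ w_q.long().T
y = acc.float() * (w_scale * x_scale)
print((y - x @ w.T).abs().max())  # close to the full-precision result, up to quantization error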
From: dxqb Date: Mon, 13 Oct 2025 11:37:16 +0200 Subject: [PATCH 06/54] various --- .pre-commit-config.yaml | 2 +- modules/modelSaver/mixin/LoRASaverMixin.py | 6 +- modules/module/LoRAModule.py | 1 + modules/module/quantized/LinearFp8.py | 1 - modules/module/quantized/LinearSVD.py | 54 +++-- modules/module/quantized/LinearW8A8.py | 191 ++++++++++++------ .../quantized/mixin/QuantizedModuleMixin.py | 2 +- modules/ui/ModelTab.py | 2 +- modules/util/LayerOffloadConductor.py | 2 +- modules/util/checkpointing_util.py | 148 ++++++++++---- modules/util/config/TrainConfig.py | 4 +- modules/util/quantization_util.py | 13 +- requirements-cuda.txt | 1 + requirements-global.txt | 10 +- training_presets/#qwen LoRA 16GB.json | 2 +- training_presets/#qwen LoRA 24GB.json | 2 +- update.sh | 2 +- 17 files changed, 298 insertions(+), 145 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 589ded7f2..036bce7d1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: check-yaml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.13.2 + rev: v0.13.3 hooks: # Run the Ruff linter, but not the formatter. - id: ruff diff --git a/modules/modelSaver/mixin/LoRASaverMixin.py b/modules/modelSaver/mixin/LoRASaverMixin.py index fc3f37440..904997a4f 100644 --- a/modules/modelSaver/mixin/LoRASaverMixin.py +++ b/modules/modelSaver/mixin/LoRASaverMixin.py @@ -41,8 +41,7 @@ def __save_safetensors( destination: str, dtype: torch.dtype | None, ): - state_dict_checkpointed = self._get_state_dict(model) - state_dict = {k.replace(".checkpoint.", "."): v for k, v in state_dict_checkpointed.items()} + state_dict = self._get_state_dict(model) save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) key_sets = self._get_convert_key_sets(model) @@ -58,8 +57,7 @@ def __save_legacy_safetensors( destination: str, dtype: torch.dtype | None, ): - state_dict_checkpointed = self._get_state_dict(model) - state_dict = {k.replace(".checkpoint.", "."): v for k, v in state_dict_checkpointed.items()} + state_dict = self._get_state_dict(model) save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) key_sets = self._get_convert_key_sets(model) diff --git a/modules/module/LoRAModule.py b/modules/module/LoRAModule.py index c9c24777b..5f4328a03 100644 --- a/modules/module/LoRAModule.py +++ b/modules/module/LoRAModule.py @@ -485,6 +485,7 @@ def __create_modules(self, orig_module: nn.Module | None, config: TrainConfig) - unsuitable = [] for name, child_module in orig_module.named_modules(): + name = name.replace(".checkpoint.", ".") if not isinstance(child_module, Linear | Conv2d): unsuitable.append(name) continue diff --git a/modules/module/quantized/LinearFp8.py b/modules/module/quantized/LinearFp8.py index 9c693b629..5a54b2f38 100644 --- a/modules/module/quantized/LinearFp8.py +++ b/modules/module/quantized/LinearFp8.py @@ -19,7 +19,6 @@ def __init__(self, *args, **kwargs): self.fp8_dtype = torch.float8_e4m3fn self._scale = torch.tensor(1.0, dtype=torch.float) self.register_buffer("scale", self._scale) - self.compute_dtype = None def original_weight_shape(self) -> tuple[int, ...]: diff --git a/modules/module/quantized/LinearSVD.py b/modules/module/quantized/LinearSVD.py index 7f3328a2f..085c39374 100644 --- a/modules/module/quantized/LinearSVD.py +++ b/modules/module/quantized/LinearSVD.py @@ -1,4 +1,3 @@ -import hashlib from contextlib import suppress from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin @@ -16,10 
+15,14 @@ def __init__(self, *args, **kwargs): def _get_tensor_hash(t: torch.Tensor) -> str: - tensor = t.detach().cpu().contiguous() - tensor_bytes = tensor.numpy().tobytes() - hash_obj = hashlib.sha256(tensor_bytes) - return hash_obj.hexdigest() + t = t.flatten().to(torch.float32) + vals = torch.stack([ + torch.sum(t), + torch.sum(t**2), + torch.sum(torch.sin(t)), + torch.sum(torch.cos(t)) + ]) + return vals.cpu().numpy().tobytes().hex() def make_svd_linear(linear_class): class LinearSVD( @@ -28,18 +31,23 @@ class LinearSVD( ): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.register_buffer("svd_up", None) - self.register_buffer("svd_down", None) + self.__svd_is_quantized = False + + #use parameters instead of buffer to allow offloading: + self.svd_up = torch.nn.Parameter(torch.empty(()), requires_grad=False) + self.svd_down = torch.nn.Parameter(torch.empty(()), requires_grad=False) def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: - if self.svd_up is None: - return super().unquantized_weight(dtype, device) - else: + if self.__svd_is_quantized: return (self.svd_up @ self.svd_down).to(dtype) + super().unquantized_weight(dtype, device) + else: + return super().unquantized_weight(dtype, device) - def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | None = None, cache_dir: str | None = None, max_cache_rank: int = 128): - if self.svd_up is not None: + @torch.no_grad() + def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | None = None, cache_dir: str | None = None, max_cache_rank: int = 128, **kwargs): + if self.__svd_is_quantized: return + self.__svd_is_quantized = True W = super().unquantized_weight(torch.float32, device) orig_device = W.device @@ -50,7 +58,7 @@ def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | Non if cache_dir is not None: filename = cache_dir + "/" + _get_tensor_hash(W) + ".pt" with suppress(FileNotFoundError): - U, S, Vh = torch.load(filename, map_location=device) + U, S, Vh = torch.load(filename, map_location=W.device) if U is None: #use full svd - torch.svd_lowrank is not reducing the quant range nearly as much: @@ -67,15 +75,23 @@ def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | Non S_r = S[:rank] Vh_r = Vh[:rank, :] - svd_down = Vh_r.clone().to(svd_dtype) - svd_up = (U_r * S_r.unsqueeze(0)).to(svd_dtype) - self.register_buffer("svd_up", svd_up) - self.register_buffer("svd_down", svd_down) + svd_down = Vh_r.clone().contiguous().to(svd_dtype) + svd_up = (U_r * S_r.unsqueeze(0)).clone().contiguous().to(svd_dtype) + weight = (W - (svd_up @ svd_down)).to(dtype=self.weight.dtype) + + if device is not None: + weight = weight.to(device=orig_device) + svd_up = svd_up.to(device=orig_device) + svd_down = svd_down.to(device=orig_device) - self.weight.data = (W - (svd_up @ svd_down)).to(dtype=self.weight.dtype, device=orig_device) - super().quantize(device) + self.requires_grad_(False) + self.svd_up = torch.nn.Parameter(svd_up, requires_grad=False) + self.svd_down = torch.nn.Parameter(svd_down, requires_grad=False) + self.weight.data = weight + super().quantize(device=device, **kwargs) def forward(self, x: torch.Tensor) -> torch.Tensor: + assert self.__svd_is_quantized assert not self.svd_down.requires_grad and not self.svd_up.requires_grad return ((x @ self.svd_down.T) @ self.svd_up.T).to(x.dtype) + super().forward(x) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py 
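The LinearSVD idea above as a rough standalone sketch (arbitrary sizes, not code from the patch; rank 16 mirrors the svd_rank default): the weight is split into a small low-rank product that stays in higher precision and a residual, and only the residual is handed to the underlying quantized layer. For real model weights this strips the dominant directions and shrinks the value range the quantizer has to cover.

import torch

W = torch.randn(1024, 1024)  # a Linear weight
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

rank = 16
svd_up = U[:, :rank] * S[:rank].unsqueeze(0)  # (out_features, rank)
svd_down = Vh[:rank, :]                       # (rank, in_features)
residual = W - svd_up @ svd_down              # only this residual gets quantized

# forward pass: low-rank path plus the (here unquantized) residual path
x = torch.randn(8, 1024)
y = (x @ svd_down.T) @ svd_up.T + x @ residual.T
print((y - x @ W.T).abs().max())              # matches the dense layer up to floating point error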
index 22ad9928e..87838d483 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -1,44 +1,65 @@ + from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit import torch -from torch import nn +from torch import Tensor, nn + +def quantize_int8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) + return q -def quantize_int8_tensorwise(x): +def quantize_int8_tensorwise_get_scale(x: Tensor) -> float: abs_max = x.abs().max() scale = (abs_max.float() / 127.0).clamp(min=1e-12) - q = x.float().mul_(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) - return q, scale + return scale +def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_int8_tensorwise_get_scale(x) + q = quantize_int8(x, scale) + return q, scale -def quantize_int8_channelwise(x, dim=-1): - abs_max = x.abs().amax(dim=dim, keepdim=True) +def quantize_int8_tokenwise_get_scale(x: Tensor) -> Tensor: + abs_max = x.abs().amax(dim=-1, keepdim=True) scale = (abs_max.float() / 127.0).clamp(min=1e-12) - q = x.float().mul_(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) + return scale + +def quantize_int8_tokenwise(x: Tensor) -> tuple[Tensor, Tensor]: + scale = quantize_int8_tokenwise_get_scale(x) + q = quantize_int8(x, scale) return q, scale +def quantize_fp8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) + return q -def quantize_fp8_tensorwise(x): +def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: abs_max = x.abs().max() scale = (abs_max.float() / 448.0).clamp(min=1e-12) - q = x.float().mul_(1.0 / scale).round().clamp(-448.0, 448.0).to(torch.float8_e4m3fn) - return q, scale - + return scale -def quantize_fp8_channelwise(x, dim=-1): - abs_max = x.abs().amax(dim=dim, keepdim=True) +def quantize_fp8_tokenwise_get_scale(x: Tensor) -> Tensor: + abs_max = x.abs().amax(dim=-1, keepdim=True) scale = (abs_max.float() / 448.0).clamp(min=1e-12) - q = x.float().mul_(1.0 / scale).round_().clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) + return scale + +def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_fp8_tensorwise_get_scale(x) + q = quantize_fp8(x, scale) return q, scale +def quantize_fp8_tokenwise(x: Tensor) -> tuple[Tensor, Tensor]: + scale = quantize_fp8_tokenwise_get_scale(x) + q = quantize_fp8(x, scale) + return q, scale -def unquantize(q, scale, compute_dtype): - return q.to(compute_dtype).mul_(scale) +def unquantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> Tensor: + return q.to(compute_dtype) * scale.to(compute_dtype) -def int8_forward_channelwise(x, weight, weight_scale, bias=None): - x_8, x_scale = quantize_int8_channelwise(x) +def int8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_int8_tokenwise(x) res = torch._int_mm(x_8, weight.T) res_scaled = res.to(x.dtype).mul_(weight_scale * x_scale) if bias is not None: @@ -46,65 +67,63 @@ def int8_forward_channelwise(x, weight, weight_scale, bias=None): return res_scaled -def fp8_forward_channelwise(x, weight, weight_scale, bias=None): - x_8, x_scale = quantize_fp8_channelwise(x) +def fp8_forward_tokenwise(x: Tensor, weight: float | Tensor, 
weight_scale: float, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_fp8_tokenwise(x) one = torch.ones(1, device=x.device) - res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) + #res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) #FIXME TODO test difference + res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=torch.float32) res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm + + res_scaled = res_scaled.to(x.dtype) #FIXME if bias is not None: res_scaled.add_(bias.to(x.dtype)) return res_scaled -def apply_scale(mm_res, weight_scale, x_scale, compute_dtype): - return mm_res.to(compute_dtype).mul_(weight_scale * x_scale) - -def int8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale): - x_8, x_scale = quantize_int8_channelwise(x) +def int8_backward_W_tensorwise_A_columnwise(x: Tensor, weight: Tensor, weight_scale: float) -> Tensor: + x_8, x_scale = quantize_int8_tokenwise(x) mm_res = triton_mm_8bit(x_8, weight) - return apply_scale(mm_res, weight_scale, x_scale, compute_dtype=x.dtype) + return mm_res.to(x.dtype).mul_(weight_scale * x_scale) -def fp8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale): - x_8, x_scale = quantize_fp8_channelwise(x) +def fp8_backward_W_tensorwise_A_columnwise(x: Tensor, weight: Tensor, weight_scale: float) -> Tensor: + x_8, x_scale = quantize_fp8_tokenwise(x) mm_res = triton_mm_8bit(x_8, weight) - return apply_scale(mm_res, weight_scale, x_scale, compute_dtype=x.dtype) + return mm_res.to(x.dtype).mul_(weight_scale * x_scale) class LinearInt8Function(torch.autograd.Function): @staticmethod - def forward(ctx, x, weight, weight_scale, bias): + def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | None) -> Tensor: ctx.save_for_backward(weight, weight_scale) - return int8_forward_channelwise(x, weight, weight_scale, bias) + return int8_forward_tokenwise(x, weight, weight_scale, bias) @staticmethod - def backward(ctx, x): + def backward(ctx, x: Tensor): if ctx.needs_input_grad != (True, False, False, False): raise NotImplementedError("Int A8W8 cannot be used for full finetuning") weight, weight_scale = ctx.saved_tensors - return int8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale), None, None, None + return int8_backward_W_tensorwise_A_columnwise(x, weight, weight_scale), None, None, None class LinearFp8Function(torch.autograd.Function): @staticmethod - def forward(ctx, x, weight, weight_scale, bias): + def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | None) -> Tensor: ctx.save_for_backward(weight, weight_scale) - return fp8_forward_channelwise(x.bfloat16(), weight, weight_scale, bias).bfloat16() + return fp8_forward_tokenwise(x.bfloat16(), weight, weight_scale, bias).bfloat16() @staticmethod - def backward(ctx, x): + def backward(ctx, x: Tensor): if ctx.needs_input_grad != (True, False, False, False): - raise NotImplementedError("Float W8A8 cannot be used for full finetuning") + raise NotImplementedError("Float A8W8 cannot be used for full finetuning") weight, weight_scale = ctx.saved_tensors - return fp8_backward_W_tensorwise_A_channelwise(x, weight, weight_scale), None, None, None + return fp8_backward_W_tensorwise_A_columnwise(x, weight, weight_scale), None, None, None class LinearW8A8( nn.Linear, QuantizedModuleMixin, QuantizedLinearMixin, ): - is_quantized: bool - def __init__(self, dtype, compute_dtype, *args, 
**kwargs): super().__init__(*args, **kwargs) @@ -112,27 +131,25 @@ def __init__(self, dtype, compute_dtype, *args, **kwargs): self._dtype = dtype self._compute_dtype = compute_dtype - self.register_buffer("scale", None) - + self.__is_quantized = False + self.register_buffer("scale", torch.tensor(1.0, dtype=torch.float32)) def original_weight_shape(self) -> tuple[int, ...]: return self.weight.shape def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: - if self.scale is not None: - return unquantize(self.weight.detach(), self.scale, self._compute_dtype).to(dtype) - else: - return self.weight.detach().to(dtype) + return unquantize(self.weight.detach(), self.scale, self._compute_dtype).to(dtype) + @torch.no_grad() def quantize(self, device: torch.device | None = None, **kwargs): - if self.scale is not None: + if self.__is_quantized: return + self.__is_quantized = True weight = self.weight.detach() orig_device = weight.device if device is not None: weight = weight.to(device=device) - if self._dtype == torch.int8: weight, scale = quantize_int8_tensorwise(weight) else: @@ -140,13 +157,20 @@ def quantize(self, device: torch.device | None = None, **kwargs): if device is not None: weight = weight.to(device=orig_device) - scale = scale.to(device=orig_device) - self.weight = nn.Parameter(weight, requires_grad=False) - self.register_buffer("scale", scale) + self.requires_grad_(False) + self.weight.data = weight + + self.scale.copy_(scale) def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + #calculate validation loss using 16 bit math: + #if not self.training: + # w = unquantize(self.weight, self.scale, compute_dtype=x_orig.dtype) + # return torch.nn.functional.linear(x_orig, w, self.bias) + assert not self.weight.requires_grad + assert self.__is_quantized x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) if x.shape[0] > 16: @@ -159,9 +183,56 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor: y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) assert y.dtype == self._compute_dtype - return y.reshape(x_orig.shape[:-1] + (self.weight.shape[0], )) + return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) + +class LinearRequantInt8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight_orig: Tensor, bias: Tensor | None) -> Tensor: + weight, weight_scale = quantize_int8_tensorwise(weight_orig) + ctx.save_for_backward(weight, weight_scale) + return int8_forward_tokenwise(x, weight, weight_scale, bias) + + @staticmethod + def backward(ctx, x: Tensor): + if ctx.needs_input_grad != (True, False, False): + raise NotImplementedError("NF4 cannot be used for full finetuning") + weight, weight_scale = ctx.saved_tensors + return int8_backward_W_tensorwise_A_columnwise(x, weight, weight_scale), None, None, None + + +def make_requant(linear_class): + class LinearRequantA8(linear_class): + def __init__(self, dtype, compute_dtype, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert dtype in [torch.int8, torch.float8_e4m3fn] + self._dtype = dtype + self._compute_dtype = compute_dtype + + def quantize(self, device: torch.device | None = None, **kwargs): + super().quantize(device, **kwargs) + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + #calculate validation loss using 16 bit math: + #if not self.training: + # w = super().unquantized_weight(dtype=self._compute_dtype, device=x_orig.device) + # return torch.nn.functional.linear(x_orig, w, self.bias) + + assert not 
self.weight.requires_grad + x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) + w = super().unquantized_weight(dtype=self._compute_dtype, device=x.device) + if x.shape[0] > 16: + if self._dtype == torch.int8: + y = LinearRequantInt8Function.apply(x, w, self.bias) + else: + raise NotImplementedError + else: + y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) + assert y.dtype == self._compute_dtype + return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) + return LinearRequantA8 def run_benchmark(fn, desc, steps=10000, warmup=500): @@ -175,9 +246,7 @@ def run_benchmark(fn, desc, steps=10000, warmup=500): @torch.no_grad() -def benchmark_int8(m, k, n, device = "cuda"): - device = "cuda" - +def benchmark_int8(m, k, n, device = 'cuda'): x = torch.randn(m,k, device=device, dtype=torch.bfloat16) x_8 = torch.ones (m,k, device=device, dtype=torch.int8) y = torch.randn(m,n, device=device, dtype=torch.bfloat16) @@ -193,12 +262,12 @@ def torch_backward(a, b): run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8") run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward int8") - run_benchmark(lambda: int8_forward_channelwise(x, w_8, w_scale), "torch forward int") - run_benchmark(lambda: int8_backward_W_tensorwise_A_channelwise(y, w_8, w_scale), "triton backward int") + run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int") + run_benchmark(lambda: int8_backward_W_tensorwise_A_columnwise(y, w_8, w_scale), "triton backward int") @torch.no_grad() -def benchmark_fp8(m, k, n, device = "cuda"): +def benchmark_fp8(m, k, n, device = 'cuda'): x = torch.randn(m,k, device=device, dtype=torch.bfloat16) x_8 = torch.ones (m,k, device=device, dtype=torch.float8_e4m3fn) y = torch.randn(m,n, device=device, dtype=torch.bfloat16) @@ -213,8 +282,8 @@ def torch_backward(a, b): torch._scaled_mm(a, b.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()) run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward fp8") run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward fp8") - run_benchmark(lambda: fp8_forward_channelwise(x, w_8, w_scale), "torch forward fp8") - run_benchmark(lambda: fp8_backward_W_tensorwise_A_channelwise(y, w_8, w_scale), "triton backward fp8") + run_benchmark(lambda: fp8_forward_tokenwise(x, w_8, w_scale), "torch forward fp8") + run_benchmark(lambda: fp8_backward_W_tensorwise_A_columnwise(y, w_8, w_scale), "triton backward fp8") if __name__ == "__main__": diff --git a/modules/module/quantized/mixin/QuantizedModuleMixin.py b/modules/module/quantized/mixin/QuantizedModuleMixin.py index 54904f59e..c110fd7ef 100644 --- a/modules/module/quantized/mixin/QuantizedModuleMixin.py +++ b/modules/module/quantized/mixin/QuantizedModuleMixin.py @@ -5,5 +5,5 @@ class QuantizedModuleMixin(metaclass=ABCMeta): @abstractmethod - def quantize(self, device: torch.device | None = None): + def quantize(self, device: torch.device | None = None, **kwargs): pass diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index c938382bb..2407a1bcf 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -274,7 +274,7 @@ def __create_dtype_options(self, include_none: bool=True, include_svd: bool=Fals ("float8 (W8) SVDQuant", DataType.FLOAT_8_SVD), ("float W8A8 SVDQuant", DataType.FLOAT_W8A8_SVD), ("int W8A8 SVDQuant", DataType.INT_W8A8_SVD), - ("nfloat4 SVDQuant", DataType.NFLOAT_4_SVD), + ("nf4 -> int W8A8 SVD", DataType.NFLOAT_4_SVD), ] if include_none: diff 
--git a/modules/util/LayerOffloadConductor.py b/modules/util/LayerOffloadConductor.py index 79229271d..70511659d 100644 --- a/modules/util/LayerOffloadConductor.py +++ b/modules/util/LayerOffloadConductor.py @@ -793,7 +793,7 @@ def __module_to_device_except_layers( sub_module_parameters = set(sum([list(x.parameters()) for x in self.__layers], [])) def convert(t): - if t in sub_module_parameters: + if t in sub_module_parameters or t.is_meta: return t return t.to(device=device) diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index 0a25717b5..5a730dea1 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -10,6 +10,24 @@ from torch import nn from torch.utils.checkpoint import checkpoint +from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock +from diffusers.models.transformers.sana_transformer import SanaTransformerBlock +from diffusers.models.transformers.transformer_hidream_image import ( + HiDreamImageSingleTransformerBlock, + HiDreamImageTransformerBlock, +) +from diffusers.models.transformers.transformer_hunyuan_video import ( + HunyuanVideoIndividualTokenRefinerBlock, + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTransformerBlock, +) +from diffusers.models.unets.unet_stable_cascade import SDCascadeAttnBlock, SDCascadeResBlock, SDCascadeTimestepBlock +from transformers.models.clip.modeling_clip import CLIPEncoderLayer +from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLDecoderLayer +from transformers.models.t5.modeling_t5 import T5Block + torch._dynamo.config.cache_size_limit = 8192 def _kwargs_to_args(fun: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, ...]: @@ -47,15 +65,24 @@ def _generate_call_index() -> int: return __current_call_index -class CheckpointLayer(torch.nn.Module): - def __init__(self, orig: nn.Module, train_device: torch.device): +class BaseCheckpointLayer(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class CheckpointLayer(BaseCheckpointLayer): + def __init__(self, orig_module: nn.Module, orig_forward, train_device: torch.device): super().__init__() - self.checkpoint = orig + + assert (orig_module is None or orig_forward is None) and not (orig_module is None and orig_forward is None) + self.checkpoint = orig_module + self.orig_forward = orig_forward + # dummy tensor that requires grad is needed for checkpointing to work when training a LoRA self.dummy = torch.zeros((1,), device=train_device, requires_grad=True) def __checkpointing_forward(self, dummy: torch.Tensor, *args, **kwargs): - return self.checkpoint(*args, **kwargs) + return self.orig_forward(*args, **kwargs) if self.checkpoint is None else self.checkpoint(*args, **kwargs) def forward(self, *args, **kwargs): if torch.is_grad_enabled(): @@ -67,12 +94,16 @@ def forward(self, *args, **kwargs): use_reentrant=False ) else: - return self.checkpoint(*args, **kwargs) + return self.orig_forward(*args, **kwargs) if self.checkpoint is None else self.checkpoint(*args, **kwargs) -class OffloadCheckpointLayer(torch.nn.Module): - def __init__(self, orig: nn.Module, train_device: torch.device, conductor: LayerOffloadConductor, layer_index: int): +class OffloadCheckpointLayer(BaseCheckpointLayer): + def __init__(self, orig_module: nn.Module, orig_forward, train_device: 
torch.device, conductor: LayerOffloadConductor, layer_index: int): super().__init__() - self.checkpoint = orig + + assert (orig_module is None or orig_forward is None) and not (orig_module is None and orig_forward is None) + self.checkpoint = orig_module + self.orig_forward = orig_forward + self.dummy = torch.zeros((1,), device=train_device, requires_grad=True) self.conductor = conductor self.layer_index = layer_index @@ -83,7 +114,8 @@ def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): self.conductor.start_forward(True) args = self.conductor.before_layer(self.layer_index, call_id, args) - output = self.checkpoint(*args) + output = self.orig_forward(*args) if self.checkpoint is None else self.checkpoint(*args) + self.conductor.after_layer(self.layer_index, call_id, args) # make sure at least one of the output tensors has a grad_fn so the output of the checkpoint has a grad_fn @@ -97,7 +129,7 @@ def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): def forward(self, *args, **kwargs): call_id = _generate_call_index() - args = _kwargs_to_args(self.checkpoint.forward, args, kwargs) + args = _kwargs_to_args(self.orig_forward if self.checkpoint is None else self.checkpoint.forward, args, kwargs) if torch.is_grad_enabled(): return checkpoint( @@ -112,7 +144,7 @@ def forward(self, *args, **kwargs): self.conductor.start_forward(False) args = self.conductor.before_layer(self.layer_index, call_id, args) - output = self.checkpoint(*args) + output = self.orig_forward(*args) if self.checkpoint is None else self.checkpoint(*args) self.conductor.after_layer(self.layer_index, call_id, args) return output @@ -123,7 +155,6 @@ def create_checkpoint( conductor: LayerOffloadConductor | None = None, layer_index: int = 0, compile: bool = False, - enabled: bool = True, ) -> Callable: if include_from_offload_param_names is None: include_from_offload_param_names = [] @@ -133,16 +164,26 @@ def create_checkpoint( conductor.add_layer(orig_module, included_offload_param_indices) if conductor is not None and conductor.offload_activated(): - layer = OffloadCheckpointLayer(orig_module, train_device, conductor, layer_index) if compile: + layer = OffloadCheckpointLayer(orig_module=orig_module, orig_forward=None, train_device=train_device, conductor=conductor, layer_index=layer_index) #don't compile the checkpointing layer - offloading cannot be compiled: orig_module.compile(fullgraph=True) + return layer + else: + #only patch forward() if possible. 
Inserting layers is necessary for torch.compile, but causes issues with at least 1 text encoder model + layer = OffloadCheckpointLayer(orig_module=None, orig_forward=orig_module.forward, train_device=train_device, conductor=conductor, layer_index=layer_index) + orig_module.forward = layer.forward + return orig_module else: - layer = CheckpointLayer(orig_module, train_device) if enabled else orig_module if compile: + layer = CheckpointLayer(orig_module=orig_module, orig_forward=None, train_device=train_device) #do compile the checkpointing layer - slightly faster layer.compile(fullgraph=True) - return layer + return layer + else: + layer = CheckpointLayer(orig_module=None, orig_forward=orig_module.forward, train_device=train_device) + orig_module.forward = layer.forward + return orig_module def _create_checkpoints_for_module_list( module_list: nn.ModuleList, @@ -154,6 +195,8 @@ def _create_checkpoints_for_module_list( ) -> int: for i, layer in enumerate(module_list): + if isinstance(module_list[i], BaseCheckpointLayer): + continue module_list[i] = create_checkpoint( layer, train_device, include_from_offload_param_names, @@ -162,6 +205,10 @@ def _create_checkpoints_for_module_list( layer_index += 1 return layer_index +def _remove_checkpoint_keys(module, state_dict, prefix, local_metadata): + for k in list(state_dict.keys()): + if ".checkpoint." in k: + state_dict[k.replace(".checkpoint.", ".")] = state_dict.pop(k) def enable_checkpointing( model: nn.Module, @@ -173,27 +220,43 @@ def enable_checkpointing( conductor = LayerOffloadConductor(model, config) layer_index = 0 - for module_list, param_names in lists: - layer_index = _create_checkpoints_for_module_list( - module_list, - param_names, - conductor if offload_enabled else None, - torch.device(config.train_device), - layer_index, - compile = compile, - ) - + for type_or_list, param_names in lists: + + assert isinstance(type_or_list, (nn.ModuleList, type)) + if isinstance(type_or_list, nn.ModuleList): + module_list = type_or_list + layer_index = _create_checkpoints_for_module_list( + module_list, + param_names, + conductor if offload_enabled else None, + torch.device(config.train_device), + layer_index, + compile = compile, + ) + else: + t = type_or_list + for child_module in model.modules(): + if isinstance(child_module, nn.ModuleList) and isinstance(child_module[0], t): + module_list = child_module + assert all(isinstance(m, t) for m in child_module) + layer_index = _create_checkpoints_for_module_list( + module_list, + param_names, + conductor if offload_enabled else None, + torch.device(config.train_device), + layer_index, + compile = compile, + ) + model._register_state_dict_hook(_remove_checkpoint_keys) return conductor - -#TODO test all models def enable_checkpointing_for_basic_transformer_blocks( model: nn.Module, config: TrainConfig, offload_enabled: bool, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, config.compile, [ - (model.transformer_blocks, []), + (BasicTransformerBlock , []), ], offload_enabled = offload_enabled, ) @@ -203,7 +266,7 @@ def enable_checkpointing_for_clip_encoder_layers( config: TrainConfig, ): return enable_checkpointing(model, config, False, [ - (model.text_model.encoder.layers, []), # No activation offloading for text encoders, because the output might be taken from the middle of the network + (CLIPEncoderLayer, []), # No activation offloading for text encoders, because the output might be taken from the middle of the network ]) def enable_checkpointing_for_stable_cascade_blocks( @@ 
-211,8 +274,9 @@ def enable_checkpointing_for_stable_cascade_blocks( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, config.compile, [ - (model.down_blocks, []), - (model.up_blocks, []), + (SDCascadeResBlock, []), + (SDCascadeAttnBlock, []), + (SDCascadeTimestepBlock, []), ]) def enable_checkpointing_for_t5_encoder_layers( @@ -220,7 +284,7 @@ def enable_checkpointing_for_t5_encoder_layers( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, False, [ - (model.encoder.block, []), + (T5Block, []), ]) @@ -229,7 +293,7 @@ def enable_checkpointing_for_gemma_layers( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, False, [ - (model.layers, []), + (Gemma2DecoderLayer, []), ]) @@ -238,7 +302,7 @@ def enable_checkpointing_for_llama_encoder_layers( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, False, [ - (model.model.layers, []), + (LlamaDecoderLayer, []), ]) def enable_checkpointing_for_qwen_encoder_layers( @@ -246,7 +310,7 @@ def enable_checkpointing_for_qwen_encoder_layers( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, False, [ - (model.model.language_model.layers, []), # TODO No activation offloading for other encoders, see above. But clip skip is not implemented for QwenVL. Then do activation offloading? + (Qwen2_5_VLDecoderLayer, []), # TODO No activation offloading for other encoders, see above. But clip skip is not implemented for QwenVL. Then do activation offloading? ]) def enable_checkpointing_for_stable_diffusion_3_transformer( @@ -254,7 +318,7 @@ def enable_checkpointing_for_stable_diffusion_3_transformer( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, config.compile, [ - (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + (JointTransformerBlock, ["hidden_states", "encoder_hidden_states"]), ]) def enable_checkpointing_for_flux_transformer( @@ -291,7 +355,7 @@ def enable_checkpointing_for_sana_transformer( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, config.compile, [ - (model.transformer_blocks, ["hidden_states"]), + (SanaTransformerBlock, ["hidden_states"]), ]) def enable_checkpointing_for_hunyuan_video_transformer( @@ -299,9 +363,9 @@ def enable_checkpointing_for_hunyuan_video_transformer( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, config.compile, [ - (model.context_embedder.token_refiner.refiner_blocks, ["hidden_states" ]), - (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), - (model.single_transformer_blocks, ["hidden_states" ]), + (HunyuanVideoIndividualTokenRefinerBlock, ["hidden_states" ]), + (HunyuanVideoTransformerBlock, ["hidden_states", "encoder_hidden_states"]), + (HunyuanVideoSingleTransformerBlock, ["hidden_states" ]), ]) def enable_checkpointing_for_hi_dream_transformer( @@ -309,6 +373,6 @@ def enable_checkpointing_for_hi_dream_transformer( config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, config.compile, [ - (model.double_stream_blocks, ["hidden_states", "encoder_hidden_states"]), - (model.single_stream_blocks, ["hidden_states" ]), + (HiDreamImageTransformerBlock, ["hidden_states", "encoder_hidden_states"]), + (HiDreamImageSingleTransformerBlock, ["hidden_states" ]), ]) diff --git a/modules/util/config/TrainConfig.py 
b/modules/util/config/TrainConfig.py index 9b6419fa0..380b6a3fa 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -871,8 +871,8 @@ def default_values() -> 'TrainConfig': data.append(("layer_offload_fraction", 0.0, float, False)) data.append(("force_circular_padding", False, bool, False)) data.append(("compile", True, bool, False)) - data.append(("svd_dtype", DataType.FLOAT_32, DataType, False)) - data.append(("svd_rank", 16, int, False)) + data.append(("svd_dtype", DataType.BFLOAT_16, DataType, False)) + data.append(("svd_rank", 128, int, False)) # data settings data.append(("concept_file_name", "training_concepts/concepts.json", str, False)) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 3c827ad92..ceeca15f3 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -4,7 +4,7 @@ from modules.module.quantized.LinearFp8 import LinearFp8 from modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear -from modules.module.quantized.LinearW8A8 import LinearW8A8 +from modules.module.quantized.LinearW8A8 import LinearW8A8, make_requant from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin from modules.util.config.TrainConfig import TrainConfig @@ -57,7 +57,7 @@ def __replace_linear_layers( visited_modules.add(id(parent_module)) - if isinstance(parent_module, nn.ModuleList): + if isinstance(parent_module, (nn.ModuleList, nn.Sequential)): for i, module in enumerate(parent_module): if isinstance(module, nn.Linear): quant_linear = __create_linear_layer(construct_fn, module, copy_parameters) @@ -92,6 +92,10 @@ def __replace_linear_layers( visited_modules=visited_modules, ) + for name, module in parent_module.named_modules(): + #ensure that all Linear layers were replaced + #https://github.com/Nerogar/OneTrainer/issues/1050 + assert not isinstance(module, nn.Linear) or isinstance(module, QuantizedLinearMixin), f"Linear layer {name} was not found in model for quantization" def replace_linear_with_quantized_layers( parent_module: nn.Module, @@ -100,7 +104,7 @@ def replace_linear_with_quantized_layers( copy_parameters: bool = False, ): if dtype.quantize_nf4(): - construct_fn = make_svd_linear(LinearNf4) if dtype.quantize_svd() else LinearNf4 + construct_fn = partial(make_svd_linear(make_requant(LinearNf4)), dtype=torch.int8, compute_dtype=torch.bfloat16) if dtype.quantize_svd() else LinearNf4 elif dtype.quantize_int8(): construct_fn = partial(make_svd_linear(bnb.nn.Linear8bitLt) if dtype.quantize_svd() else bnb.nn.Linear8bitLt, has_fp16_weights=False) elif dtype.quantize_fp8(): @@ -185,7 +189,8 @@ def get_offload_tensors(module: nn.Module) -> list[torch.Tensor]: if isinstance(module, nn.Linear) and module.bias is not None: tensors += [module.bias] if isinstance(module, BaseLinearSVD): - tensors += [module.svd_up, module.svd_down] + tensors += [module.svd_up] + tensors += [module.svd_down] return tensors diff --git a/requirements-cuda.txt b/requirements-cuda.txt index c60082db5..c17429e6f 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -4,6 +4,7 @@ torch==2.8.0+cu128 torchvision==0.23.0+cu128 onnxruntime-gpu==1.22.0 nvidia-nccl-cu12==2.27.3; sys_platform == "linux" +triton-windows==3.4.0.post20; sys_platform == "win32" # optimizers bitsandbytes==0.46.0 # bitsandbytes for 8-bit optimizers and weight quantization diff --git a/requirements-global.txt 
b/requirements-global.txt index 2111386eb..bb812fa47 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -1,7 +1,7 @@ # base requirements numpy==2.2.6 opencv-python==4.11.0.86 -pillow==11.2.1 +pillow==11.3.0 imagesize==1.4.1 #for concept statistics tqdm==4.67.1 PyYAML==6.0.2 @@ -22,8 +22,8 @@ pytorch-lightning==2.5.1.post0 #Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers (see BaseQwenSetup): -e git+https://github.com/huggingface/diffusers.git@9b721db#egg=diffusers -transformers==4.52.4 -sentencepiece==0.2.0 # transitive dependency of transformers for tokenizer loading +transformers==4.56.2 +sentencepiece==0.2.1 # transitive dependency of transformers for tokenizer loading omegaconf==2.3.0 # needed to load stable diffusion from single ckpt files invisible-watermark==0.2.0 # needed for the SDXL pipeline @@ -44,7 +44,7 @@ prodigyopt==1.1.2 # prodigy optimizer schedulefree==1.4.1 # schedule-free optimizers pytorch_optimizer==3.6.0 # pytorch optimizers prodigy-plus-schedule-free==2.0.0 # Prodigy plus optimizer -adv_optm==1.0.3 # advanced optimizers +adv_optm==1.0.6 # advanced optimizers # Profiling scalene==1.5.51 @@ -59,4 +59,4 @@ fabric==3.2.2 # debug psutil==7.0.0 requests==2.32.3 -deepdiff==8.5.0 +deepdiff==8.6.1 # output easy-to-read diff for troubleshooting diff --git a/training_presets/#qwen LoRA 16GB.json b/training_presets/#qwen LoRA 16GB.json index 91dff7b6f..eae8bad72 100644 --- a/training_presets/#qwen LoRA 16GB.json +++ b/training_presets/#qwen LoRA 16GB.json @@ -5,7 +5,7 @@ "model_type": "QWEN", "resolution": "512", "gradient_checkpointing": "CPU_OFFLOADED", - "layer_offload_fraction": 0.65, + "layer_offload_fraction": 0.5, "dataloader_threads": 1, "prior": { "train": true, diff --git a/training_presets/#qwen LoRA 24GB.json b/training_presets/#qwen LoRA 24GB.json index ab3e10bcf..6b76111e2 100644 --- a/training_presets/#qwen LoRA 24GB.json +++ b/training_presets/#qwen LoRA 24GB.json @@ -5,7 +5,7 @@ "model_type": "QWEN", "resolution": "512", "gradient_checkpointing": "CPU_OFFLOADED", - "layer_offload_fraction": 0.35, + "layer_offload_fraction": 0.1, "dataloader_threads": 1, "prior": { "train": true, diff --git a/update.sh b/update.sh index 0d12760fa..5656a45e5 100755 --- a/update.sh +++ b/update.sh @@ -7,7 +7,7 @@ cd -- "$(dirname -- "${BASH_SOURCE[0]}")" # Pull the latest changes via Git. echo "[OneTrainer] Updating OneTrainer to latest version from Git repository..." -git pull +#git pull # Load the newest version of the function library.
source "lib.include.sh" From cb02a4c9082d7d350eeef1062d9b2f12624a5307 Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 13 Oct 2025 11:41:50 +0200 Subject: [PATCH 07/54] various --- modules/ui/ModelTab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 2407a1bcf..063ddd00e 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -264,7 +264,7 @@ def __create_dtype_options(self, include_none: bool=True, include_svd: bool=Fals ("float16", DataType.FLOAT_16), ("float8 (W8)", DataType.FLOAT_8), ("float W8A8", DataType.FLOAT_W8A8), - ("int W8A8", DataType.INT_W8A8), + #("int W8A8", DataType.INT_W8A8), # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 ("nfloat4", DataType.NFLOAT_4), ] From efb9073a328e3cb090152b12555bcb4b5ef78472 Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 13 Oct 2025 11:53:50 +0200 Subject: [PATCH 08/54] various --- modules/module/quantized/LinearW8A8.py | 4 +--- modules/ui/ModelTab.py | 4 ++-- modules/util/quantization_util.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 87838d483..1e7c3aceb 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -70,11 +70,9 @@ def int8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: floa def fp8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_fp8_tokenwise(x) one = torch.ones(1, device=x.device) - #res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) #FIXME TODO test difference - res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=torch.float32) + res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm - res_scaled = res_scaled.to(x.dtype) #FIXME if bias is not None: res_scaled.add_(bias.to(x.dtype)) return res_scaled diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 063ddd00e..76965cdd1 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -264,7 +264,7 @@ def __create_dtype_options(self, include_none: bool=True, include_svd: bool=Fals ("float16", DataType.FLOAT_16), ("float8 (W8)", DataType.FLOAT_8), ("float W8A8", DataType.FLOAT_W8A8), - #("int W8A8", DataType.INT_W8A8), + #("int W8A8", DataType.INT_W8A8), #not recommended # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 ("nfloat4", DataType.NFLOAT_4), ] @@ -274,7 +274,7 @@ def __create_dtype_options(self, include_none: bool=True, include_svd: bool=Fals ("float8 (W8) SVDQuant", DataType.FLOAT_8_SVD), ("float W8A8 SVDQuant", DataType.FLOAT_W8A8_SVD), ("int W8A8 SVDQuant", DataType.INT_W8A8_SVD), - ("nf4 -> int W8A8 SVD", DataType.NFLOAT_4_SVD), + ("nfloat4 SVD", DataType.NFLOAT_4_SVD), ] if include_none: diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index ceeca15f3..a8adf584b 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -4,7 +4,7 @@ from modules.module.quantized.LinearFp8 import LinearFp8 from 
modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear -from modules.module.quantized.LinearW8A8 import LinearW8A8, make_requant +from modules.module.quantized.LinearW8A8 import LinearW8A8 from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin from modules.util.config.TrainConfig import TrainConfig @@ -104,7 +104,7 @@ def replace_linear_with_quantized_layers( copy_parameters: bool = False, ): if dtype.quantize_nf4(): - construct_fn = partial(make_svd_linear(make_requant(LinearNf4)), dtype=torch.int8, compute_dtype=torch.bfloat16) if dtype.quantize_svd() else LinearNf4 + construct_fn = make_svd_linear(LinearNf4) if dtype.quantize_svd() else LinearNf4 elif dtype.quantize_int8(): construct_fn = partial(make_svd_linear(bnb.nn.Linear8bitLt) if dtype.quantize_svd() else bnb.nn.Linear8bitLt, has_fp16_weights=False) elif dtype.quantize_fp8(): From 35ba023f999c5b15d62f3c6837603d41f5023bc3 Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 13 Oct 2025 12:30:49 +0200 Subject: [PATCH 09/54] cleanup --- modules/module/quantized/LinearW8A8.py | 55 -------------------------- 1 file changed, 55 deletions(-) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 1e7c3aceb..8bd2ec18c 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -162,11 +162,6 @@ def quantize(self, device: torch.device | None = None, **kwargs): self.scale.copy_(scale) def forward(self, x_orig: torch.Tensor) -> torch.Tensor: - #calculate validation loss using 16 bit math: - #if not self.training: - # w = unquantize(self.weight, self.scale, compute_dtype=x_orig.dtype) - # return torch.nn.functional.linear(x_orig, w, self.bias) - assert not self.weight.requires_grad assert self.__is_quantized x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) @@ -183,56 +178,6 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor: assert y.dtype == self._compute_dtype return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) -class LinearRequantInt8Function(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, weight_orig: Tensor, bias: Tensor | None) -> Tensor: - weight, weight_scale = quantize_int8_tensorwise(weight_orig) - ctx.save_for_backward(weight, weight_scale) - return int8_forward_tokenwise(x, weight, weight_scale, bias) - - @staticmethod - def backward(ctx, x: Tensor): - if ctx.needs_input_grad != (True, False, False): - raise NotImplementedError("NF4 cannot be used for full finetuning") - - weight, weight_scale = ctx.saved_tensors - return int8_backward_W_tensorwise_A_columnwise(x, weight, weight_scale), None, None, None - - -def make_requant(linear_class): - class LinearRequantA8(linear_class): - def __init__(self, dtype, compute_dtype, *args, **kwargs): - super().__init__(*args, **kwargs) - - assert dtype in [torch.int8, torch.float8_e4m3fn] - self._dtype = dtype - self._compute_dtype = compute_dtype - - def quantize(self, device: torch.device | None = None, **kwargs): - super().quantize(device, **kwargs) - - def forward(self, x_orig: torch.Tensor) -> torch.Tensor: - #calculate validation loss using 16 bit math: - #if not self.training: - # w = super().unquantized_weight(dtype=self._compute_dtype, device=x_orig.device) - # return torch.nn.functional.linear(x_orig, w, self.bias) - - assert not self.weight.requires_grad - x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) - w = 
super().unquantized_weight(dtype=self._compute_dtype, device=x.device) - if x.shape[0] > 16: - if self._dtype == torch.int8: - y = LinearRequantInt8Function.apply(x, w, self.bias) - else: - raise NotImplementedError - else: - y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) - - assert y.dtype == self._compute_dtype - return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) - return LinearRequantA8 - - def run_benchmark(fn, desc, steps=10000, warmup=500): from tqdm import tqdm for _ in range(warmup): From 5633b210988f323c57247ce52243aa83d36364cb Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 14 Oct 2025 18:34:49 +0200 Subject: [PATCH 10/54] torch.compile bug workaround --- modules/model/ChromaModel.py | 5 +++++ modules/model/QwenModel.py | 5 +++++ modules/modelSetup/BaseChromaSetup.py | 2 ++ 3 files changed, 12 insertions(+) diff --git a/modules/model/ChromaModel.py b/modules/model/ChromaModel.py index 7b10891ff..14a0b6ffe 100644 --- a/modules/model/ChromaModel.py +++ b/modules/model/ChromaModel.py @@ -217,6 +217,11 @@ def encode_text( #prune tokens that are masked in all batch samples: seq_lengths = bool_attention_mask.sum(dim=1) max_seq_length = seq_lengths.max().item() + + if max_seq_length % 16 > 0: + #attention processors and/or torch.compile can have issues with uneven sequence length: + max_seq_length += (16 - max_seq_length % 16) + text_encoder_output = text_encoder_output[:, :max_seq_length, :] bool_attention_mask = bool_attention_mask[:, :max_seq_length] diff --git a/modules/model/QwenModel.py b/modules/model/QwenModel.py index b2748de6a..c92f8393e 100644 --- a/modules/model/QwenModel.py +++ b/modules/model/QwenModel.py @@ -175,6 +175,11 @@ def encode_text( #https://github.com/huggingface/diffusers/issues/12344 seq_lengths = tokens_mask.sum(dim=1) max_seq_length = seq_lengths.max().item() + + if max_seq_length % 16 > 0: + #attention processors and/or torch.compile can have issues with uneven sequence length: + max_seq_length += (16 - max_seq_length % 16) + text_encoder_output = text_encoder_output[:, :max_seq_length, :] bool_attention_mask = tokens_mask[:, :max_seq_length].bool() diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 96544e122..6a2c60d60 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -226,9 +226,11 @@ def predict( packed_latent_input = model.pack_latents(latent_input) image_seq_len = packed_latent_input.shape[1] + text_seq_len = text_encoder_output.shape[1] image_attention_mask = torch.full((packed_latent_input.shape[0], image_seq_len), True, dtype=torch.bool, device=text_attention_mask.device) attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + assert image_seq_len % 16 == 0 and (image_seq_len + text_seq_len) % 16 == 0 packed_predicted_flow = model.transformer( hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), timestep=timestep / 1000, From c37f8052688a1571bf99755478fc1b26dfebe1cd Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 14 Oct 2025 19:30:12 +0200 Subject: [PATCH 11/54] same workaround for Qwen --- modules/model/QwenModel.py | 6 ++---- modules/modelSetup/BaseQwenSetup.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/modules/model/QwenModel.py b/modules/model/QwenModel.py index c92f8393e..4c72e0fda 100644 --- a/modules/model/QwenModel.py +++ b/modules/model/QwenModel.py @@ -140,7 +140,7 @@ def encode_text( tokenizer_output = self.tokenizer( text, 
max_length=PROMPT_MAX_LENGTH + DEFAULT_PROMPT_TEMPLATE_CROP_START, - padding='longest', + padding='max_length', truncation=True, return_tensors="pt" ) @@ -168,9 +168,7 @@ def encode_text( if text_encoder_dropout_probability is not None and text_encoder_dropout_probability > 0.0: raise NotImplementedError #https://github.com/Nerogar/OneTrainer/issues/957 - #prune tokens that are masked in all batch samples - #this is still necessary even though we are using 'longest' padding, because cached - #encoder outputs by MGDS are always PROMPT_MAX_LENGTH + #prune tokens that are masked in all batch samples: #this is good for efficiency, but also FIXME currently required by the diffusers pipeline: #https://github.com/huggingface/diffusers/issues/12344 seq_lengths = tokens_mask.sum(dim=1) diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index f17fa9c86..1239cab56 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -131,10 +131,10 @@ def predict( latent_input = scaled_noisy_latent_image packed_latent_input = model.pack_latents(latent_input) - txt_seq_lens = text_attention_mask.sum(dim=1).tolist() #FIXME this is the only case that the transformer accepts: #see https://github.com/huggingface/diffusers/issues/12344 - assert max(txt_seq_lens) == text_encoder_output.shape[1] + #actual text sequence lengths can be shorter, but they might be padded and masked + txt_seq_lens = [text_encoder_output.shape[1]] * text_encoder_output.shape[0] #FIXME list of lists is not according to type hint, but according to diffusers code: #https://github.com/huggingface/diffusers/issues/12295 From d7532dc3e9d258360150e4e7bb4b668ccea7562f Mon Sep 17 00:00:00 2001 From: dxqb Date: Wed, 15 Oct 2025 13:03:48 +0200 Subject: [PATCH 12/54] gguf --- modules/modelLoader/chroma/ChromaModelLoader.py | 5 ++++- modules/modelLoader/flux/FluxModelLoader.py | 5 ++++- modules/modelLoader/qwen/QwenModelLoader.py | 5 ++++- modules/ui/ModelTab.py | 7 +++++-- modules/util/enum/DataType.py | 1 + 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/modules/modelLoader/chroma/ChromaModelLoader.py b/modules/modelLoader/chroma/ChromaModelLoader.py index c66817933..0f24f6c13 100644 --- a/modules/modelLoader/chroma/ChromaModelLoader.py +++ b/modules/modelLoader/chroma/ChromaModelLoader.py @@ -3,6 +3,7 @@ from modules.model.ChromaModel import ChromaModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin +from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -13,6 +14,7 @@ AutoencoderKL, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler, + GGUFQuantizationConfig, ) from transformers import T5EncoderModel, T5Tokenizer @@ -98,7 +100,8 @@ def __load_diffusers( transformer = ChromaTransformer2DModel.from_single_file( transformer_model_name, #avoid loading the transformer in float32: - torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype() + torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype(), + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior == DataType.GGUF else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( transformer, weight_dtypes.prior, weight_dtypes.train_dtype diff --git
a/modules/modelLoader/flux/FluxModelLoader.py b/modules/modelLoader/flux/FluxModelLoader.py index 5e07da653..8447a9a15 100644 --- a/modules/modelLoader/flux/FluxModelLoader.py +++ b/modules/modelLoader/flux/FluxModelLoader.py @@ -3,6 +3,7 @@ from modules.model.FluxModel import FluxModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin +from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -14,6 +15,7 @@ FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, + GGUFQuantizationConfig, ) from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer @@ -134,7 +136,8 @@ def __load_diffusers( transformer = FluxTransformer2DModel.from_single_file( transformer_model_name, #avoid loading the transformer in float32: - torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype() + torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype(), + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior == DataType.GGUF else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( transformer, weight_dtypes.prior, weight_dtypes.train_dtype diff --git a/modules/modelLoader/qwen/QwenModelLoader.py b/modules/modelLoader/qwen/QwenModelLoader.py index 883af1fb2..0a1f6de7e 100644 --- a/modules/modelLoader/qwen/QwenModelLoader.py +++ b/modules/modelLoader/qwen/QwenModelLoader.py @@ -3,6 +3,7 @@ from modules.model.QwenModel import QwenModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin +from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -12,6 +13,7 @@ from diffusers import ( AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler, + GGUFQuantizationConfig, QwenImageTransformer2DModel, ) from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer @@ -100,7 +102,8 @@ def __load_diffusers( config=base_model_name, subfolder="transformer", #avoid loading the transformer in float32: - torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype() + torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype(), + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior == DataType.GGUF else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( transformer, weight_dtypes.prior, weight_dtypes.train_dtype diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index d27118bb3..4ecd80676 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -257,7 +257,7 @@ def __setup_hi_dream_ui(self): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __create_dtype_options(self, include_none:bool=True) -> list[tuple[str, DataType]]: + def __create_dtype_options(self, include_none:bool=True, include_gguf=False) -> list[tuple[str, DataType]]: options = [ ("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16), @@ -267,6 +267,9 @@ def __create_dtype_options(self, include_none:bool=True) -> list[tuple[str, Data ("nfloat4", DataType.NFLOAT_4), ] + if 
include_gguf: + options.append(("GGUF", DataType.GGUF)) + if include_none: options.insert(0, ("", DataType.NONE)) @@ -336,7 +339,7 @@ def __create_base_components( # prior weight dtype components.label(self.scroll_frame, row, 3, "Override Prior Data Type", tooltip="Overrides the prior weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(include_gguf=True), self.ui_state, "prior.weight_dtype") row += 1 diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py index 45e547507..dd15e01d8 100644 --- a/modules/util/enum/DataType.py +++ b/modules/util/enum/DataType.py @@ -12,6 +12,7 @@ class DataType(Enum): TFLOAT_32 = 'TFLOAT_32' INT_8 = 'INT_8' NFLOAT_4 = 'NFLOAT_4' + GGUF = 'GGUF' def __str__(self): return self.value From 2de19c9f08ad81dcdb5798e10f8bab45b0f6cef9 Mon Sep 17 00:00:00 2001 From: dxqb Date: Wed, 15 Oct 2025 16:05:03 +0200 Subject: [PATCH 13/54] gguf --- .../modelLoader/chroma/ChromaModelLoader.py | 3 +- modules/modelLoader/flux/FluxModelLoader.py | 3 +- modules/modelLoader/qwen/QwenModelLoader.py | 3 +- modules/module/quantized/LinearGGUFA8.py | 40 +++++++++++++++++++ modules/ui/ModelTab.py | 2 + modules/util/enum/DataType.py | 7 ++++ modules/util/quantization_util.py | 23 +++++++---- 7 files changed, 68 insertions(+), 13 deletions(-) create mode 100644 modules/module/quantized/LinearGGUFA8.py diff --git a/modules/modelLoader/chroma/ChromaModelLoader.py b/modules/modelLoader/chroma/ChromaModelLoader.py index 0f24f6c13..b70253e17 100644 --- a/modules/modelLoader/chroma/ChromaModelLoader.py +++ b/modules/modelLoader/chroma/ChromaModelLoader.py @@ -3,7 +3,6 @@ from modules.model.ChromaModel import ChromaModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin -from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -101,7 +100,7 @@ def __load_diffusers( transformer_model_name, #avoid loading the transformer in float32: torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype(), - quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior == DataType.GGUF else None, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior.is_gguf() else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( transformer, weight_dtypes.prior, weight_dtypes.train_dtype diff --git a/modules/modelLoader/flux/FluxModelLoader.py b/modules/modelLoader/flux/FluxModelLoader.py index 8447a9a15..47cc99095 100644 --- a/modules/modelLoader/flux/FluxModelLoader.py +++ b/modules/modelLoader/flux/FluxModelLoader.py @@ -3,7 +3,6 @@ from modules.model.FluxModel import FluxModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin -from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -137,7 +136,7 @@ def __load_diffusers( transformer_model_name, #avoid loading the transformer in float32: torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype(), - quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if 
weight_dtypes.prior == DataType.GGUF else None, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior.is_gguf() else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( transformer, weight_dtypes.prior, weight_dtypes.train_dtype diff --git a/modules/modelLoader/qwen/QwenModelLoader.py b/modules/modelLoader/qwen/QwenModelLoader.py index 0a1f6de7e..79d1eb2e7 100644 --- a/modules/modelLoader/qwen/QwenModelLoader.py +++ b/modules/modelLoader/qwen/QwenModelLoader.py @@ -3,7 +3,6 @@ from modules.model.QwenModel import QwenModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin -from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -103,7 +102,7 @@ def __load_diffusers( subfolder="transformer", #avoid loading the transformer in float32: torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype(), - quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior == DataType.GGUF else None, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.prior.is_gguf() else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( transformer, weight_dtypes.prior, weight_dtypes.train_dtype diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py new file mode 100644 index 000000000..940bc5323 --- /dev/null +++ b/modules/module/quantized/LinearGGUFA8.py @@ -0,0 +1,40 @@ + +from modules.module.quantized.LinearW8A8 import ( + LinearFp8Function, + LinearInt8Function, + quantize_fp8_tensorwise, + quantize_int8_tensorwise, +) + +import torch + +from diffusers.quantizers.gguf.utils import GGUFLinear, dequantize_gguf_tensor + + +class LinearGGUFA8(GGUFLinear): + def __init__(self, dtype, compute_dtype, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert dtype in [torch.int8, torch.float8_e4m3fn] + self._dtype = dtype + self._compute_dtype = compute_dtype + + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + assert not self.weight.requires_grad + x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) + w = dequantize_gguf_tensor(self.weight) + + if x.shape[0] > 16: + if self._dtype == torch.int8: + #TODO tokenwise instead? 
Higher quality, but requires quantization on forward and backward + q, q_scale = quantize_int8_tensorwise(w) + y = LinearInt8Function.apply(x, q, q_scale, self.bias) + else: + q, q_scale = quantize_fp8_tensorwise(w) + y = LinearFp8Function.apply(x, q, q_scale, self.bias) + else: + y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) + + assert y.dtype == self._compute_dtype + return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 38a8ad886..da8fc0652 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -279,6 +279,8 @@ def __create_dtype_options(self, include_none:bool=True, include_gguf=False, inc if include_gguf: options.append(("GGUF", DataType.GGUF)) + options.append(("GGUF A8 float", DataType.GGUF_A8_FLOAT)) + options.append(("GGUF A8 int", DataType.GGUF_A8_INT)) if include_none: options.insert(0, ("", DataType.NONE)) diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py index 68a4a4e33..25b314d8d 100644 --- a/modules/util/enum/DataType.py +++ b/modules/util/enum/DataType.py @@ -19,6 +19,8 @@ class DataType(Enum): FLOAT_W8A8_SVD = 'FLOAT_W8A8_SVD' INT_W8A8_SVD = 'INT_W8A8_SVD' GGUF = 'GGUF' + GGUF_A8_FLOAT = 'GGUF_A8_FLOAT' + GGUF_A8_INT = 'GGUF_A8_INT' def __str__(self): return self.value @@ -56,6 +58,11 @@ def is_quantized(self): DataType.INT_W8A8_SVD, DataType.NFLOAT_4_SVD] + def is_gguf(self): + return self in [DataType.GGUF, + DataType.GGUF_A8_FLOAT, + DataType.GGUF_A8_INT] + def quantize_fp8(self): return self == DataType.FLOAT_8 or self == DataType.FLOAT_8_SVD diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index a8adf584b..7bdefbe3e 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -3,6 +3,7 @@ from functools import partial from modules.module.quantized.LinearFp8 import LinearFp8 +from modules.module.quantized.LinearGGUFA8 import LinearGGUFA8 from modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear from modules.module.quantized.LinearW8A8 import LinearW8A8 from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin @@ -25,7 +26,6 @@ def __create_linear_layer(construct_fn, module: nn.Linear, copy_parameters: bool) -> nn.Module: bias = module.bias is not None - quant_linear = construct_fn( in_features=module.in_features, out_features=module.out_features, @@ -33,12 +33,14 @@ def __create_linear_layer(construct_fn, module: nn.Linear, copy_parameters: bool ) if copy_parameters: - quant_linear.weight = type(quant_linear.weight)(module.weight) + quant_linear.weight = type(quant_linear.weight)(module.weight, requires_grad=False) if bias: - quant_linear.bias = type(quant_linear.bias)(module.bias) + quant_linear.bias = type(quant_linear.bias)(module.bias, requires_grad=False) return quant_linear +from diffusers.quantizers.gguf.utils import GGUFLinear + def __replace_linear_layers( parent_module: nn.Module, @@ -47,6 +49,7 @@ def __replace_linear_layers( copy_parameters: bool = False, name_prefix: str = "", visited_modules: set[int] | None = None, + convert_type = nn.Linear, ): if keep_in_fp32_modules is None: keep_in_fp32_modules = [] @@ -59,7 +62,7 @@ def __replace_linear_layers( if isinstance(parent_module, (nn.ModuleList, nn.Sequential)): for i, module in enumerate(parent_module): - if isinstance(module, nn.Linear): + if isinstance(module, convert_type): quant_linear = __create_linear_layer(construct_fn, module, copy_parameters) parent_module[i] = 
quant_linear del module @@ -78,7 +81,7 @@ def __replace_linear_layers( continue module = getattr(parent_module, attr_name) - if isinstance(module, nn.Linear): + if isinstance(module, convert_type): quant_linear = __create_linear_layer(construct_fn, module, copy_parameters) setattr(parent_module, attr_name, quant_linear) del module @@ -95,7 +98,8 @@ def __replace_linear_layers( for name, module in parent_module.named_modules(): #ensure that all Linear layers were replaced #https://github.com/Nerogar/OneTrainer/issues/1050 - assert not isinstance(module, nn.Linear) or isinstance(module, QuantizedLinearMixin), f"Linear layer {name} was not found in model for quantization" + assert (not isinstance(module, convert_type) + or isinstance(module, QuantizedLinearMixin, LinearGGUFA8)), f"Linear layer {name_prefix}.{name} was not found in model for quantization" def replace_linear_with_quantized_layers( parent_module: nn.Module, @@ -112,7 +116,11 @@ def replace_linear_with_quantized_layers( elif dtype.quantize_intW8A8(): construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.int8, compute_dtype=torch.bfloat16) elif dtype.quantize_fpW8A8(): - construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + elif dtype == DataType.GGUF_A8_INT: + construct_fn = partial(LinearGGUFA8, dtype=torch.int8, compute_dtype=torch.bfloat16) + elif dtype == DataType.GGUF_A8_FLOAT: + construct_fn = partial(LinearGGUFA8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) else: return @@ -121,6 +129,7 @@ def replace_linear_with_quantized_layers( construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, copy_parameters=copy_parameters, + convert_type = GGUFLinear if dtype.is_gguf() else nn.Linear, ) From a3cd93636d7f6b9429ab044273ec2bd9976dbf57 Mon Sep 17 00:00:00 2001 From: dxqb Date: Wed, 15 Oct 2025 16:42:20 +0200 Subject: [PATCH 14/54] bugfix --- modules/util/quantization_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 7bdefbe3e..5310f6d49 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -99,7 +99,7 @@ def __replace_linear_layers( #ensure that all Linear layers were replaced #https://github.com/Nerogar/OneTrainer/issues/1050 assert (not isinstance(module, convert_type) - or isinstance(module, QuantizedLinearMixin, LinearGGUFA8)), f"Linear layer {name_prefix}.{name} was not found in model for quantization" + or isinstance(module, (QuantizedLinearMixin, LinearGGUFA8))), f"Linear layer {name_prefix}.{name} was not found in model for quantization" def replace_linear_with_quantized_layers( parent_module: nn.Module, From bcf0b658c4891b0c03fafbfc7b4b1e2637da7fe5 Mon Sep 17 00:00:00 2001 From: dxqb Date: Wed, 15 Oct 2025 17:08:48 +0200 Subject: [PATCH 15/54] requirements --- requirements-global.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-global.txt b/requirements-global.txt index 85ee87118..88c43aa69 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -21,7 +21,7 @@ pytorch-lightning==2.5.1.post0 # diffusion models #Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers (see BaseQwenSetup): -e 
git+https://github.com/huggingface/diffusers.git@9b721db#egg=diffusers - +gguf==0.17.1 transformers==4.52.4 sentencepiece==0.2.0 # transitive dependency of transformers for tokenizer loading omegaconf==2.3.0 # needed to load stable diffusion from single ckpt files From 5a2e59045b2bd126d639fd794e2f6a709a8e813b Mon Sep 17 00:00:00 2001 From: dxqb Date: Thu, 16 Oct 2025 13:47:52 +0200 Subject: [PATCH 16/54] name changes, axis wise --- modules/module/quantized/LinearW8A8.py | 38 +++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 8bd2ec18c..495be8680 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -21,13 +21,13 @@ def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: q = quantize_int8(x, scale) return q, scale -def quantize_int8_tokenwise_get_scale(x: Tensor) -> Tensor: - abs_max = x.abs().amax(dim=-1, keepdim=True) +def quantize_int8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) scale = (abs_max.float() / 127.0).clamp(min=1e-12) return scale -def quantize_int8_tokenwise(x: Tensor) -> tuple[Tensor, Tensor]: - scale = quantize_int8_tokenwise_get_scale(x) +def quantize_int8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_int8_axiswise_get_scale(x, dim) q = quantize_int8(x, scale) return q, scale @@ -40,8 +40,8 @@ def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: scale = (abs_max.float() / 448.0).clamp(min=1e-12) return scale -def quantize_fp8_tokenwise_get_scale(x: Tensor) -> Tensor: - abs_max = x.abs().amax(dim=-1, keepdim=True) +def quantize_fp8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) scale = (abs_max.float() / 448.0).clamp(min=1e-12) return scale @@ -50,8 +50,8 @@ def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: q = quantize_fp8(x, scale) return q, scale -def quantize_fp8_tokenwise(x: Tensor) -> tuple[Tensor, Tensor]: - scale = quantize_fp8_tokenwise_get_scale(x) +def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_fp8_axiswise_get_scale(x, dim) q = quantize_fp8(x, scale) return q, scale @@ -59,16 +59,15 @@ def unquantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> return q.to(compute_dtype) * scale.to(compute_dtype) def int8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: - x_8, x_scale = quantize_int8_tokenwise(x) + x_8, x_scale = quantize_int8_axiswise(x, dim=-1) res = torch._int_mm(x_8, weight.T) res_scaled = res.to(x.dtype).mul_(weight_scale * x_scale) if bias is not None: res_scaled.add_(bias.to(x.dtype)) return res_scaled - def fp8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: - x_8, x_scale = quantize_fp8_tokenwise(x) + x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) one = torch.ones(1, device=x.device) res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm @@ -78,17 +77,18 @@ def fp8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float return res_scaled -def int8_backward_W_tensorwise_A_columnwise(x: Tensor, weight: Tensor, weight_scale: float) -> Tensor: - x_8, x_scale = quantize_int8_tokenwise(x) +def int8_backward_axiswise(x: Tensor, 
weight: Tensor, weight_scale: float) -> Tensor: + x_8, x_scale = quantize_int8_axiswise(x, dim=-1) mm_res = triton_mm_8bit(x_8, weight) return mm_res.to(x.dtype).mul_(weight_scale * x_scale) -def fp8_backward_W_tensorwise_A_columnwise(x: Tensor, weight: Tensor, weight_scale: float) -> Tensor: - x_8, x_scale = quantize_fp8_tokenwise(x) +def fp8_backward_axiswise(x: Tensor, weight: Tensor, weight_scale: float) -> Tensor: + x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) mm_res = triton_mm_8bit(x_8, weight) return mm_res.to(x.dtype).mul_(weight_scale * x_scale) + class LinearInt8Function(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | None) -> Tensor: @@ -101,7 +101,7 @@ def backward(ctx, x: Tensor): raise NotImplementedError("Int A8W8 cannot be used for full finetuning") weight, weight_scale = ctx.saved_tensors - return int8_backward_W_tensorwise_A_columnwise(x, weight, weight_scale), None, None, None + return int8_backward_axiswise(x, weight, weight_scale), None, None, None class LinearFp8Function(torch.autograd.Function): @staticmethod @@ -115,7 +115,7 @@ def backward(ctx, x: Tensor): raise NotImplementedError("Float A8W8 cannot be used for full finetuning") weight, weight_scale = ctx.saved_tensors - return fp8_backward_W_tensorwise_A_columnwise(x, weight, weight_scale), None, None, None + return fp8_backward_axiswise(x, weight, weight_scale), None, None, None class LinearW8A8( nn.Linear, @@ -206,7 +206,7 @@ def torch_backward(a, b): run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward int8") run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int") - run_benchmark(lambda: int8_backward_W_tensorwise_A_columnwise(y, w_8, w_scale), "triton backward int") + run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int") @torch.no_grad() @@ -226,7 +226,7 @@ def torch_backward(a, b): run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward fp8") run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward fp8") run_benchmark(lambda: fp8_forward_tokenwise(x, w_8, w_scale), "torch forward fp8") - run_benchmark(lambda: fp8_backward_W_tensorwise_A_columnwise(y, w_8, w_scale), "triton backward fp8") + run_benchmark(lambda: fp8_backward_axiswise(y, w_8, w_scale), "triton backward fp8") if __name__ == "__main__": From 882154d46b372554f20130310442bdabcdc57192 Mon Sep 17 00:00:00 2001 From: dxqb Date: Thu, 16 Oct 2025 13:53:22 +0200 Subject: [PATCH 17/54] merge --- modules/module/quantized/LinearGGUFA8.py | 50 ++++++++++++++++++------ modules/module/quantized/LinearW8A8.py | 2 +- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index 940bc5323..6de4afaa0 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -1,16 +1,43 @@ - -from modules.module.quantized.LinearW8A8 import ( - LinearFp8Function, - LinearInt8Function, - quantize_fp8_tensorwise, - quantize_int8_tensorwise, -) - +from modules.module.quantized.LinearW8A8 import quantize_fp8_axiswise, quantize_int8_axiswise, fp8_forward_tokenwise, int8_forward_tokenwise, int8_backward_W_tensorwise_A_axiswise, fp8_backward_W_tensorwise_A_axiswise import torch +from torch import Tensor from diffusers.quantizers.gguf.utils import GGUFLinear, dequantize_gguf_tensor +class LinearGGUFIntA8RequantFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: 
Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(weight) + #axiswise performs better than tensorwise in tests, even though + #it requires another requant during backward - but requant is cheap + weight_q, weight_scale = quantize_int8_axiswise(weight, dim=-1) + return int8_forward_tokenwise(x, weight_q, weight_scale.T, bias) + + @staticmethod + def backward(ctx, x: Tensor): + if ctx.needs_input_grad != (True, False, False): + raise NotImplementedError("GGUF cannot be used for full finetuning") + weight, = ctx.saved_tensors + weight_q, weight_scale = quantize_int8_axiswise(weight, dim=0) + return int8_backward_W_tensorwise_A_axiswise(x, weight_q, weight_scale), None, None + +class LinearGGUFFpA8RequantFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(weight) + weight_q, weight_scale = quantize_fp8_axiswise(weight, dim=-1) + return fp8_forward_tokenwise(x, weight_q, weight_scale, bias) + + @staticmethod + def backward(ctx, x: Tensor): + if ctx.needs_input_grad != (True, False, False): + raise NotImplementedError("GGUF cannot be used for full finetuning") + weight, = ctx.saved_tensors + weight_q, weight_scale = quantize_fp8_axiswise(weight, dim=0) + return fp8_backward_W_tensorwise_A_axiswise(x, weight_q, weight_scale), None, None + + class LinearGGUFA8(GGUFLinear): def __init__(self, dtype, compute_dtype, *args, **kwargs): super().__init__(*args, **kwargs) @@ -27,12 +54,9 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor: if x.shape[0] > 16: if self._dtype == torch.int8: - #TODO tokenwise instead? Higher quality, but requires quantization on forward and backward - q, q_scale = quantize_int8_tensorwise(w) - y = LinearInt8Function.apply(x, q, q_scale, self.bias) + y = LinearGGUFIntA8RequantFunction.apply(x, w, self.bias) else: - q, q_scale = quantize_fp8_tensorwise(w) - y = LinearFp8Function.apply(x, q, q_scale, self.bias) + y = LinearGGUFFpA8RequantFunction.apply(x, w, self.bias) else: y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 495be8680..cef372bb3 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -69,7 +69,7 @@ def int8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: floa def fp8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) one = torch.ones(1, device=x.device) - res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) + res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm if bias is not None: From e5317d38d41273c303c324f00814a3e83c555565 Mon Sep 17 00:00:00 2001 From: dxqb Date: Thu, 16 Oct 2025 13:57:58 +0200 Subject: [PATCH 18/54] big type hint --- modules/module/quantized/LinearW8A8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 495be8680..49ed6607d 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -58,7 +58,7 @@ def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: def unquantize(q: Tensor, scale: float | Tensor, 
compute_dtype: torch.dtype) -> Tensor: return q.to(compute_dtype) * scale.to(compute_dtype) -def int8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: +def int8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_int8_axiswise(x, dim=-1) res = torch._int_mm(x_8, weight.T) res_scaled = res.to(x.dtype).mul_(weight_scale * x_scale) @@ -66,7 +66,7 @@ def int8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: floa res_scaled.add_(bias.to(x.dtype)) return res_scaled -def fp8_forward_tokenwise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: +def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) one = torch.ones(1, device=x.device) res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) From c8cb33bc14c710e06ef639b4396bd662a9210026 Mon Sep 17 00:00:00 2001 From: dxqb Date: Thu, 16 Oct 2025 14:38:26 +0200 Subject: [PATCH 19/54] use axis-wise quantization for both forward and backward --- modules/module/quantized/LinearGGUFA8.py | 50 +++++++++++++++++++----- modules/module/quantized/LinearW8A8.py | 2 +- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index 6de4afaa0..067cecabf 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -1,41 +1,73 @@ -from modules.module.quantized.LinearW8A8 import quantize_fp8_axiswise, quantize_int8_axiswise, fp8_forward_tokenwise, int8_forward_tokenwise, int8_backward_W_tensorwise_A_axiswise, fp8_backward_W_tensorwise_A_axiswise +from modules.module.quantized.LinearW8A8 import ( + quantize_fp8_axiswise, + quantize_int8_axiswise, +) +from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit + import torch from torch import Tensor from diffusers.quantizers.gguf.utils import GGUFLinear, dequantize_gguf_tensor +def int8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_int8_axiswise(x, dim=-1) + w_8, w_scale = quantize_int8_axiswise(weight, dim=-1) + res = torch._int_mm(x_8, w_8.T) + res_scaled = res.to(x.dtype).mul_(w_scale.T).mul_(x_scale) + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def fp8_forward_both_axiswise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) + w_8, w_scale = quantize_fp8_axiswise(weight, dim=-1) + one = torch.ones(1, device=x.device) + res = torch._scaled_mm(x_8, w_8.T, scale_a=one, scale_b=one, out_dtype=x.dtype) + res_scaled = res.mul_(w_scale.T).mul_(x_scale) #much faster than scaled by _scaled_mm + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def int8_backward_both_axiswise(x: Tensor, weight: Tensor) -> Tensor: + x_8, x_scale = quantize_int8_axiswise(x, dim=-1) + w_8, w_scale = quantize_int8_axiswise(weight, dim=0) + mm_res = triton_mm_8bit(x_8, w_8) + return mm_res.to(x.dtype).mul_(w_scale).mul_(x_scale) + +def fp8_backward_both_axiswise(x: Tensor, weight: Tensor) -> Tensor: + x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) + w_8, w_scale = quantize_fp8_axiswise(weight, dim=0) + mm_res = triton_mm_8bit(x_8, w_8) + return mm_res.to(x.dtype).mul_(w_scale).mul_(x_scale) 
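# --- Editor's note: illustrative sketch, not part of the patch above ----------
# The new *_both_axiswise helpers quantize the activation per token and the
# weight per output row ("axis-wise"), run an 8-bit matmul, and then multiply
# both scale vectors back in. Compared with a single tensor-wise scale, this
# keeps the per-row quantization error smaller at the cost of re-quantizing the
# weight on every call. A minimal, self-contained PyTorch version of the same
# idea (the helper names below are my own, not from this repository):
import torch

def quantize_int8_axiswise_sketch(t: torch.Tensor, dim: int):
    # One scale per slice along `dim`, chosen so that the max |value| maps to 127.
    scale = t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale

def int8_mm_both_axiswise_sketch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (tokens, in_features), w: (out_features, in_features)
    x_8, x_scale = quantize_int8_axiswise_sketch(x, dim=-1)   # scale: (tokens, 1)
    w_8, w_scale = quantize_int8_axiswise_sketch(w, dim=-1)   # scale: (out, 1)
    acc = x_8.to(torch.int64) @ w_8.T.to(torch.int64)         # exact integer accumulate
    # The per-row scales factor out of each dot product, so they can be applied last.
    return acc.to(x.dtype) * w_scale.T * x_scale

# Usage check against a plain float matmul:
#   x = torch.randn(8, 64); w = torch.randn(16, 64)
#   print((int8_mm_both_axiswise_sketch(x, w) - x @ w.T).abs().max())
# -------------------------------------------------------------------------------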
+ class LinearGGUFIntA8RequantFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ctx.save_for_backward(weight) #axiswise performs better than tensorwise in tests, even though #it requires another requant during backward - but requant is cheap - weight_q, weight_scale = quantize_int8_axiswise(weight, dim=-1) - return int8_forward_tokenwise(x, weight_q, weight_scale.T, bias) + return int8_forward_both_axiswise(x, weight, bias) @staticmethod def backward(ctx, x: Tensor): if ctx.needs_input_grad != (True, False, False): raise NotImplementedError("GGUF cannot be used for full finetuning") weight, = ctx.saved_tensors - weight_q, weight_scale = quantize_int8_axiswise(weight, dim=0) - return int8_backward_W_tensorwise_A_axiswise(x, weight_q, weight_scale), None, None + return int8_backward_both_axiswise(x, weight), None, None class LinearGGUFFpA8RequantFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ctx.save_for_backward(weight) - weight_q, weight_scale = quantize_fp8_axiswise(weight, dim=-1) - return fp8_forward_tokenwise(x, weight_q, weight_scale, bias) + return fp8_forward_both_axiswise(x, weight, bias) @staticmethod def backward(ctx, x: Tensor): if ctx.needs_input_grad != (True, False, False): raise NotImplementedError("GGUF cannot be used for full finetuning") weight, = ctx.saved_tensors - weight_q, weight_scale = quantize_fp8_axiswise(weight, dim=0) - return fp8_backward_W_tensorwise_A_axiswise(x, weight_q, weight_scale), None, None + return fp8_backward_both_axiswise(x, weight), None, None class LinearGGUFA8(GGUFLinear): diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 16856eab9..49ed6607d 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -69,7 +69,7 @@ def int8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) one = torch.ones(1, device=x.device) - res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) + res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm if bias is not None: From 2d4a0c31c1535533fbfc8b194f6e62e96eef17c4 Mon Sep 17 00:00:00 2001 From: dxqb Date: Thu, 16 Oct 2025 15:58:01 +0200 Subject: [PATCH 20/54] initial --- modules/cloud/BaseCloud.py | 2 + .../modelLoader/chroma/ChromaModelLoader.py | 12 ++--- modules/modelLoader/flux/FluxModelLoader.py | 16 +++---- .../modelLoader/hiDream/HiDreamModelLoader.py | 4 +- .../hunyuanVideo/HunyuanVideoModelLoader.py | 4 +- .../pixartAlpha/PixArtAlphaModelLoader.py | 2 +- modules/modelLoader/qwen/QwenModelLoader.py | 12 ++--- modules/modelLoader/sana/SanaModelLoader.py | 2 +- .../StableDiffusion3ModelLoader.py | 4 +- modules/modelSampler/FluxSampler.py | 12 ++--- modules/modelSampler/HiDreamSampler.py | 8 ++-- modules/modelSampler/HunyuanVideoSampler.py | 4 +- .../modelSampler/StableDiffusion3Sampler.py | 8 ++-- modules/modelSetup/BaseChromaSetup.py | 2 +- modules/modelSetup/BaseFluxSetup.py | 6 +-- modules/modelSetup/BaseHiDreamSetup.py | 6 +-- modules/modelSetup/BaseHunyuanVideoSetup.py | 6 +-- modules/modelSetup/BasePixArtAlphaSetup.py | 2 +- 
modules/modelSetup/BaseQwenSetup.py | 2 +- modules/modelSetup/BaseSanaSetup.py | 2 +- .../modelSetup/BaseStableDiffusion3Setup.py | 4 +- modules/modelSetup/ChromaFineTuneSetup.py | 6 +-- modules/modelSetup/ChromaLoRASetup.py | 6 +-- modules/modelSetup/FluxFineTuneSetup.py | 6 +-- modules/modelSetup/FluxLoRASetup.py | 6 +-- modules/modelSetup/HiDreamFineTuneSetup.py | 6 +-- modules/modelSetup/HiDreamLoRASetup.py | 6 +-- .../modelSetup/HunyuanVideoFineTuneSetup.py | 6 +-- modules/modelSetup/HunyuanVideoLoRASetup.py | 6 +-- .../modelSetup/PixArtAlphaFineTuneSetup.py | 6 +-- modules/modelSetup/PixArtAlphaLoRASetup.py | 6 +-- modules/modelSetup/QwenFineTuneSetup.py | 6 +-- modules/modelSetup/QwenLoRASetup.py | 6 +-- modules/modelSetup/SanaFineTuneSetup.py | 6 +-- modules/modelSetup/SanaLoRASetup.py | 6 +-- .../StableDiffusion3FineTuneSetup.py | 6 +-- .../modelSetup/StableDiffusion3LoRASetup.py | 6 +-- modules/trainer/CloudTrainer.py | 1 + modules/ui/ModelTab.py | 48 +++++++++++++------ modules/ui/TrainingTab.py | 14 +++--- modules/util/ModelNames.py | 2 + modules/util/ModelWeightDtypes.py | 3 ++ modules/util/config/SampleConfig.py | 6 +-- modules/util/config/TrainConfig.py | 25 +++++++++- training_presets/#chroma Finetune 16GB.json | 2 +- training_presets/#chroma Finetune 24GB.json | 2 +- training_presets/#chroma Finetune 8GB.json | 2 +- training_presets/#chroma LoRA 16GB.json | 2 +- training_presets/#chroma LoRA 24GB.json | 2 +- training_presets/#chroma LoRA 8GB.json | 2 +- training_presets/#flux LoRA.json | 2 +- training_presets/#qwen Finetune 16GB.json | 2 +- training_presets/#qwen Finetune 24GB.json | 2 +- training_presets/#qwen LoRA 16GB.json | 2 +- training_presets/#qwen LoRA 24GB.json | 2 +- training_presets/#sd 3.json | 2 +- 56 files changed, 195 insertions(+), 144 deletions(-) diff --git a/modules/cloud/BaseCloud.py b/modules/cloud/BaseCloud.py index 9ca0a77a5..ec9a90015 100644 --- a/modules/cloud/BaseCloud.py +++ b/modules/cloud/BaseCloud.py @@ -43,6 +43,8 @@ def upload_config(self,commands : TrainCommands=None): self.file_sync.sync_up(local=Path(self.config.local_base_model_name),remote=Path(self.config.base_model_name)) if hasattr(self.config.prior,"local_model_name"): self.file_sync.sync_up(local=Path(self.config.prior.local_model_name),remote=Path(self.config.prior.model_name)) + if hasattr(self.config.transformer,"local_model_name"): + self.file_sync.sync_up(local=Path(self.config.transformer.local_model_name),remote=Path(self.config.transformer.model_name)) if hasattr(self.config,"local_lora_model_name"): self.file_sync.sync_up(local=Path(self.config.local_lora_model_name),remote=Path(self.config.lora_model_name)) diff --git a/modules/modelLoader/chroma/ChromaModelLoader.py b/modules/modelLoader/chroma/ChromaModelLoader.py index c66817933..c34a5f9ae 100644 --- a/modules/modelLoader/chroma/ChromaModelLoader.py +++ b/modules/modelLoader/chroma/ChromaModelLoader.py @@ -98,15 +98,15 @@ def __load_diffusers( transformer = ChromaTransformer2DModel.from_single_file( transformer_model_name, #avoid loading the transformer in float32: - torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype() + torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype() ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.prior, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype ) else: transformer = 
self._load_diffusers_sub_module( ChromaTransformer2DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", @@ -142,7 +142,7 @@ def load( try: self.__load_internal( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, ) return except Exception: @@ -150,7 +150,7 @@ def load( try: self.__load_diffusers( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, ) return except Exception: @@ -158,7 +158,7 @@ def load( try: self.__load_safetensors( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, ) return except Exception: diff --git a/modules/modelLoader/flux/FluxModelLoader.py b/modules/modelLoader/flux/FluxModelLoader.py index 5e07da653..0301bf27d 100644 --- a/modules/modelLoader/flux/FluxModelLoader.py +++ b/modules/modelLoader/flux/FluxModelLoader.py @@ -134,15 +134,15 @@ def __load_diffusers( transformer = FluxTransformer2DModel.from_single_file( transformer_model_name, #avoid loading the transformer in float32: - torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype() + torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype() ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.prior, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype ) else: transformer = self._load_diffusers_sub_module( FluxTransformer2DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", @@ -171,7 +171,7 @@ def __load_safetensors( transformer = FluxTransformer2DModel.from_single_file( #always load transformer separately even though FluxPipeLine.from_single_file() could load it, to avoid loading in float32: transformer_model_name if transformer_model_name else base_model_name, - torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype() + torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype() ) pipeline = FluxPipeline.from_single_file( pretrained_model_link_or_path=base_model_name, @@ -219,7 +219,7 @@ def __load_safetensors( print("text encoder 2 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.prior, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype ) model.model_type = model_type @@ -242,7 +242,7 @@ def load( try: self.__load_internal( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, ) return @@ -251,7 +251,7 @@ def load( try: self.__load_diffusers( - model, model_type, weight_dtypes, 
model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, ) return @@ -260,7 +260,7 @@ def load( try: self.__load_safetensors( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, ) return diff --git a/modules/modelLoader/hiDream/HiDreamModelLoader.py b/modules/modelLoader/hiDream/HiDreamModelLoader.py index 79f509458..92bce0a9a 100644 --- a/modules/modelLoader/hiDream/HiDreamModelLoader.py +++ b/modules/modelLoader/hiDream/HiDreamModelLoader.py @@ -187,7 +187,7 @@ def __load_diffusers( transformer = self._load_diffusers_sub_module( HiDreamImageTransformer2DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", @@ -264,7 +264,7 @@ def __load_safetensors( print("text encoder 2 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.prior, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype ) model.model_type = model_type diff --git a/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py b/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py index 897a25b49..fe044ca03 100644 --- a/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py +++ b/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py @@ -129,7 +129,7 @@ def __load_diffusers( transformer = self._load_diffusers_sub_module( HunyuanVideoTransformer3DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", @@ -192,7 +192,7 @@ def __load_safetensors( print("text encoder 2 (clip l) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.prior, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype ) model.model_type = model_type diff --git a/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py b/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py index 13706b315..1428a4079 100644 --- a/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py +++ b/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py @@ -74,7 +74,7 @@ def __load_diffusers( transformer = self._load_diffusers_sub_module( Transformer2DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", diff --git a/modules/modelLoader/qwen/QwenModelLoader.py b/modules/modelLoader/qwen/QwenModelLoader.py index 883af1fb2..55704ed12 100644 --- a/modules/modelLoader/qwen/QwenModelLoader.py +++ b/modules/modelLoader/qwen/QwenModelLoader.py @@ -100,15 +100,15 @@ def __load_diffusers( config=base_model_name, subfolder="transformer", #avoid loading the transformer in float32: - torch_dtype = torch.bfloat16 if weight_dtypes.prior.torch_dtype() is None else weight_dtypes.prior.torch_dtype() + torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype() ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, 
weight_dtypes.prior, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype ) else: transformer = self._load_diffusers_sub_module( QwenImageTransformer2DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", @@ -144,7 +144,7 @@ def load( #TODO share code between models try: self.__load_internal( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, ) return except Exception: @@ -152,7 +152,7 @@ def load( #TODO share code between models try: self.__load_diffusers( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, ) return except Exception: @@ -160,7 +160,7 @@ def load( #TODO share code between models try: self.__load_safetensors( - model, model_type, weight_dtypes, model_names.base_model, model_names.prior_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, ) return except Exception: diff --git a/modules/modelLoader/sana/SanaModelLoader.py b/modules/modelLoader/sana/SanaModelLoader.py index c0b78e0d0..e4c2d638f 100644 --- a/modules/modelLoader/sana/SanaModelLoader.py +++ b/modules/modelLoader/sana/SanaModelLoader.py @@ -74,7 +74,7 @@ def __load_diffusers( transformer = self._load_diffusers_sub_module( SanaTransformer2DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", diff --git a/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py b/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py index 5306a89d7..f5fa6871f 100644 --- a/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py +++ b/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py @@ -129,7 +129,7 @@ def __load_diffusers( transformer = self._load_diffusers_sub_module( SD3Transformer2DModel, - weight_dtypes.prior, + weight_dtypes.transformer, weight_dtypes.train_dtype, base_model_name, "transformer", @@ -220,7 +220,7 @@ def __load_safetensors( print("text encoder 3 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.prior, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype ) model.model_type = model_type diff --git a/modules/modelSampler/FluxSampler.py b/modules/modelSampler/FluxSampler.py index ef085cac4..181df08fb 100644 --- a/modules/modelSampler/FluxSampler.py +++ b/modules/modelSampler/FluxSampler.py @@ -50,7 +50,7 @@ def __sample_base( noise_scheduler: NoiseScheduler, text_encoder_1_layer_skip: int = 0, text_encoder_2_layer_skip: int = 0, - prior_attention_mask: bool = False, + transformer_attention_mask: bool = False, on_update_progress: Callable[[int, int], None] = lambda _, __: None, ) -> ModelSamplerOutput: with self.model.autocast_context: @@ -75,7 +75,7 @@ def __sample_base( train_device=self.train_device, text_encoder_1_layer_skip=text_encoder_1_layer_skip, text_encoder_2_layer_skip=text_encoder_2_layer_skip, - apply_attention_mask=prior_attention_mask, + apply_attention_mask=transformer_attention_mask, ) 
self.model.text_encoder_to(self.temp_device) @@ -199,7 +199,7 @@ def __sample_inpainting( mask_image_path: str = "", text_encoder_1_layer_skip: int = 0, text_encoder_2_layer_skip: int = 0, - prior_attention_mask: bool = False, + transformer_attention_mask: bool = False, on_update_progress: Callable[[int, int], None] = lambda _, __: None, ) -> ModelSamplerOutput: with self.model.autocast_context: @@ -315,7 +315,7 @@ def __sample_inpainting( train_device=self.train_device, text_encoder_1_layer_skip=text_encoder_1_layer_skip, text_encoder_2_layer_skip=text_encoder_2_layer_skip, - apply_attention_mask=prior_attention_mask, + apply_attention_mask=transformer_attention_mask, ) self.model.text_encoder_to(self.temp_device) @@ -435,7 +435,7 @@ def sample( mask_image_path=sample_config.mask_image_path, text_encoder_1_layer_skip=sample_config.text_encoder_1_layer_skip, text_encoder_2_layer_skip=sample_config.text_encoder_2_layer_skip, - prior_attention_mask=sample_config.prior_attention_mask, + transformer_attention_mask=sample_config.transformer_attention_mask, on_update_progress=on_update_progress, ) else: @@ -451,7 +451,7 @@ def sample( noise_scheduler=sample_config.noise_scheduler, text_encoder_1_layer_skip=sample_config.text_encoder_1_layer_skip, text_encoder_2_layer_skip=sample_config.text_encoder_2_layer_skip, - prior_attention_mask=sample_config.prior_attention_mask, + transformer_attention_mask=sample_config.transformer_attention_mask, on_update_progress=on_update_progress, ) diff --git a/modules/modelSampler/HiDreamSampler.py b/modules/modelSampler/HiDreamSampler.py index ef7c0db27..3e8618d4e 100644 --- a/modules/modelSampler/HiDreamSampler.py +++ b/modules/modelSampler/HiDreamSampler.py @@ -45,7 +45,7 @@ def __sample_base( cfg_scale: float, noise_scheduler: NoiseScheduler, text_encoder_3_layer_skip: int = 0, - prior_attention_mask: bool = False, + transformer_attention_mask: bool = False, on_update_progress: Callable[[int, int], None] = lambda _, __: None, ) -> ModelSamplerOutput: with self.model.autocast_context: @@ -71,7 +71,7 @@ def __sample_base( text=prompt, train_device=self.train_device, text_encoder_3_layer_skip=text_encoder_3_layer_skip, - apply_attention_mask=prior_attention_mask, + apply_attention_mask=transformer_attention_mask, )) negative_text_encoder_3_prompt_embedding, negative_text_encoder_4_prompt_embedding, negative_pooled_prompt_embedding = \ @@ -80,7 +80,7 @@ def __sample_base( text=negative_prompt, train_device=self.train_device, text_encoder_3_layer_skip=text_encoder_3_layer_skip, - apply_attention_mask=prior_attention_mask, + apply_attention_mask=transformer_attention_mask, )) combined_text_encoder_3_prompt_embedding = torch.cat( @@ -181,7 +181,7 @@ def sample( cfg_scale=sample_config.cfg_scale, noise_scheduler=sample_config.noise_scheduler, text_encoder_3_layer_skip=sample_config.text_encoder_3_layer_skip, - prior_attention_mask=sample_config.prior_attention_mask, + transformer_attention_mask=sample_config.transformer_attention_mask, on_update_progress=on_update_progress, ) diff --git a/modules/modelSampler/HunyuanVideoSampler.py b/modules/modelSampler/HunyuanVideoSampler.py index 5e1b4ddde..96f93eade 100644 --- a/modules/modelSampler/HunyuanVideoSampler.py +++ b/modules/modelSampler/HunyuanVideoSampler.py @@ -48,7 +48,7 @@ def __sample_base( noise_scheduler: NoiseScheduler, text_encoder_1_layer_skip: int = 0, text_encoder_2_layer_skip: int = 0, - prior_attention_mask: bool = False, + transformer_attention_mask: bool = False, on_update_progress: Callable[[int, 
int], None] = lambda _, __: None, ) -> ModelSamplerOutput: with self.model.autocast_context: @@ -195,7 +195,7 @@ def sample( noise_scheduler=sample_config.noise_scheduler, text_encoder_1_layer_skip=sample_config.text_encoder_1_layer_skip, text_encoder_2_layer_skip=sample_config.text_encoder_2_layer_skip, - prior_attention_mask=sample_config.prior_attention_mask, + transformer_attention_mask=sample_config.transformer_attention_mask, on_update_progress=on_update_progress, ) diff --git a/modules/modelSampler/StableDiffusion3Sampler.py b/modules/modelSampler/StableDiffusion3Sampler.py index c983acc13..51c687caa 100644 --- a/modules/modelSampler/StableDiffusion3Sampler.py +++ b/modules/modelSampler/StableDiffusion3Sampler.py @@ -47,7 +47,7 @@ def __sample_base( text_encoder_1_layer_skip: int = 0, text_encoder_2_layer_skip: int = 0, text_encoder_3_layer_skip: int = 0, - prior_attention_mask: bool = False, + transformer_attention_mask: bool = False, on_update_progress: Callable[[int, int], None] = lambda _, __: None, ) -> ModelSamplerOutput: with self.model.autocast_context: @@ -73,7 +73,7 @@ def __sample_base( text_encoder_1_layer_skip=text_encoder_1_layer_skip, text_encoder_2_layer_skip=text_encoder_2_layer_skip, text_encoder_3_layer_skip=text_encoder_3_layer_skip, - apply_attention_mask=prior_attention_mask, + apply_attention_mask=transformer_attention_mask, )) negative_prompt_embedding, negative_pooled_prompt_embedding = self.model.combine_text_encoder_output( @@ -83,7 +83,7 @@ def __sample_base( text_encoder_1_layer_skip=text_encoder_1_layer_skip, text_encoder_2_layer_skip=text_encoder_2_layer_skip, text_encoder_3_layer_skip=text_encoder_3_layer_skip, - apply_attention_mask=prior_attention_mask, + apply_attention_mask=transformer_attention_mask, )) combined_prompt_embedding = torch.cat([negative_prompt_embedding, prompt_embedding], dim=0) @@ -180,7 +180,7 @@ def sample( text_encoder_1_layer_skip=sample_config.text_encoder_1_layer_skip, text_encoder_2_layer_skip=sample_config.text_encoder_2_layer_skip, text_encoder_3_layer_skip=sample_config.text_encoder_3_layer_skip, - prior_attention_mask=sample_config.prior_attention_mask, + transformer_attention_mask=sample_config.transformer_attention_mask, on_update_progress=on_update_progress, ) diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 6a9342b41..183ea3a2f 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -62,7 +62,7 @@ def setup_optimizations( apply_circular_padding_to_conv2d(model.transformer_lora) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().vae, config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, diff --git a/modules/modelSetup/BaseFluxSetup.py b/modules/modelSetup/BaseFluxSetup.py index b53f638f6..64b33093f 100644 --- a/modules/modelSetup/BaseFluxSetup.py +++ b/modules/modelSetup/BaseFluxSetup.py @@ -63,7 +63,7 @@ def setup_optimizations( apply_circular_padding_to_conv2d(model.transformer_lora) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().text_encoder_2, config.weight_dtypes().vae, @@ -230,7 +230,7 @@ def predict( if 
'text_encoder_2_hidden_state' in batch and not config.train_text_encoder_2_or_embedding() else None, text_encoder_1_dropout_probability=config.text_encoder.dropout_probability, text_encoder_2_dropout_probability=config.text_encoder_2.dropout_probability, - apply_attention_mask=config.prior.attention_mask, + apply_attention_mask=config.transformer.attention_mask, ) latent_image = batch['latent_image'] @@ -268,7 +268,7 @@ def predict( latent_input = scaled_noisy_latent_image if model.transformer.config.guidance_embeds: - guidance = torch.tensor([config.prior.guidance_scale], device=self.train_device) + guidance = torch.tensor([config.transformer.guidance_scale], device=self.train_device) guidance = guidance.expand(latent_input.shape[0]) else: guidance = None diff --git a/modules/modelSetup/BaseHiDreamSetup.py b/modules/modelSetup/BaseHiDreamSetup.py index a9a6da8af..f98ecec43 100644 --- a/modules/modelSetup/BaseHiDreamSetup.py +++ b/modules/modelSetup/BaseHiDreamSetup.py @@ -61,7 +61,7 @@ def setup_optimizations( enable_checkpointing_for_llama_encoder_layers(model.text_encoder_4, config) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().text_encoder_2, config.weight_dtypes().text_encoder_3, @@ -90,7 +90,7 @@ def setup_optimizations( config.train_dtype, config.fallback_train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, config.weight_dtypes().embedding if config.train_any_embedding() else None, ], @@ -326,7 +326,7 @@ def predict( text_encoder_2_dropout_probability=config.text_encoder_2.dropout_probability, text_encoder_3_dropout_probability=config.text_encoder_3.dropout_probability, text_encoder_4_dropout_probability=config.text_encoder_4.dropout_probability, - apply_attention_mask=config.prior.attention_mask, + apply_attention_mask=config.transformer.attention_mask, )) latent_image = batch['latent_image'] diff --git a/modules/modelSetup/BaseHunyuanVideoSetup.py b/modules/modelSetup/BaseHunyuanVideoSetup.py index 786b89380..70db6806a 100644 --- a/modules/modelSetup/BaseHunyuanVideoSetup.py +++ b/modules/modelSetup/BaseHunyuanVideoSetup.py @@ -63,7 +63,7 @@ def setup_optimizations( apply_circular_padding_to_conv2d(model.transformer_lora) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().text_encoder_2, config.weight_dtypes().vae, @@ -77,7 +77,7 @@ def setup_optimizations( config.train_dtype, config.fallback_train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, config.weight_dtypes().embedding if config.train_any_embedding() else None, ], @@ -259,7 +259,7 @@ def predict( latent_input = scaled_noisy_latent_image if model.transformer.config.guidance_embeds: - guidance = torch.tensor([config.prior.guidance_scale * 1000.0], device=self.train_device) + guidance = torch.tensor([config.transformer.guidance_scale * 1000.0], device=self.train_device) guidance = guidance.expand(latent_input.shape[0]) else: guidance = None diff --git a/modules/modelSetup/BasePixArtAlphaSetup.py 
b/modules/modelSetup/BasePixArtAlphaSetup.py index 3e8ad0a38..754a1fd5c 100644 --- a/modules/modelSetup/BasePixArtAlphaSetup.py +++ b/modules/modelSetup/BasePixArtAlphaSetup.py @@ -62,7 +62,7 @@ def setup_optimizations( apply_circular_padding_to_conv2d(model.transformer_lora) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().vae, config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index c8dcd4913..93984fd95 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -59,7 +59,7 @@ def setup_optimizations( apply_circular_padding_to_conv2d(model.transformer_lora) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().vae, config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, diff --git a/modules/modelSetup/BaseSanaSetup.py b/modules/modelSetup/BaseSanaSetup.py index 29b810774..ee58332d3 100644 --- a/modules/modelSetup/BaseSanaSetup.py +++ b/modules/modelSetup/BaseSanaSetup.py @@ -64,7 +64,7 @@ def setup_optimizations( apply_circular_padding_to_conv2d(model.transformer_lora) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().vae, config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, diff --git a/modules/modelSetup/BaseStableDiffusion3Setup.py b/modules/modelSetup/BaseStableDiffusion3Setup.py index 4f2888554..4d20536cf 100644 --- a/modules/modelSetup/BaseStableDiffusion3Setup.py +++ b/modules/modelSetup/BaseStableDiffusion3Setup.py @@ -63,7 +63,7 @@ def setup_optimizations( apply_circular_padding_to_conv2d(model.transformer_lora) model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ - config.weight_dtypes().prior, + config.weight_dtypes().transformer, config.weight_dtypes().text_encoder, config.weight_dtypes().text_encoder_2, config.weight_dtypes().text_encoder_3, @@ -284,7 +284,7 @@ def predict( text_encoder_1_dropout_probability=config.text_encoder.dropout_probability, text_encoder_2_dropout_probability=config.text_encoder_2.dropout_probability, text_encoder_3_dropout_probability=config.text_encoder_3.dropout_probability, - apply_attention_mask=config.prior.attention_mask, + apply_attention_mask=config.transformer.attention_mask, )) latent_image = batch['latent_image'] diff --git a/modules/modelSetup/ChromaFineTuneSetup.py b/modules/modelSetup/ChromaFineTuneSetup.py index 0df444426..82897b89a 100644 --- a/modules/modelSetup/ChromaFineTuneSetup.py +++ b/modules/modelSetup/ChromaFineTuneSetup.py @@ -40,7 +40,7 @@ def create_parameters( "embeddings" ) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior, freeze=ModuleFilter.create(config), debug=config.debug_mode) + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, 
freeze=ModuleFilter.create(config), debug=config.debug_mode) return parameter_group_collection @@ -52,7 +52,7 @@ def __setup_requires_grad( self._setup_embeddings_requires_grad(model, config) self._setup_model_part_requires_grad("text_encoder", model.text_encoder, config.text_encoder, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) model.vae.requires_grad_(False) @@ -95,7 +95,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/ChromaLoRASetup.py b/modules/modelSetup/ChromaLoRASetup.py index fbe978386..9802c7b00 100644 --- a/modules/modelSetup/ChromaLoRASetup.py +++ b/modules/modelSetup/ChromaLoRASetup.py @@ -41,7 +41,7 @@ def create_parameters( "embeddings" ) - self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) return parameter_group_collection @@ -57,7 +57,7 @@ def __setup_requires_grad( model.vae.requires_grad_(False) self._setup_model_part_requires_grad("text_encoder_lora", model.text_encoder_lora, config.text_encoder, model.train_progress) - self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -123,7 +123,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/FluxFineTuneSetup.py b/modules/modelSetup/FluxFineTuneSetup.py index ae495ef72..55ce6e4e7 100644 --- a/modules/modelSetup/FluxFineTuneSetup.py +++ b/modules/modelSetup/FluxFineTuneSetup.py @@ -47,7 +47,7 @@ def create_parameters( "embeddings_2" ) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior, + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, freeze=ModuleFilter.create(config), debug=config.debug_mode) return parameter_group_collection @@ -60,7 +60,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_1", model.text_encoder_1, config.text_encoder, model.train_progress) self._setup_model_part_requires_grad("text_encoder_2", model.text_encoder_2, config.text_encoder_2, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) model.vae.requires_grad_(False) @@ -117,7 +117,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/FluxLoRASetup.py b/modules/modelSetup/FluxLoRASetup.py index 59ee3e51f..0b58af8f4 100644 --- a/modules/modelSetup/FluxLoRASetup.py +++ b/modules/modelSetup/FluxLoRASetup.py @@ -48,7 +48,7 @@ def create_parameters( "embeddings_2" ) - 
self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) return parameter_group_collection def __setup_requires_grad( @@ -66,7 +66,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_1_lora", model.text_encoder_1_lora, config.text_encoder, model.train_progress) self._setup_model_part_requires_grad("text_encoder_2_lora", model.text_encoder_2_lora, config.text_encoder_2, model.train_progress) - self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -159,7 +159,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/HiDreamFineTuneSetup.py b/modules/modelSetup/HiDreamFineTuneSetup.py index c5d5f3216..b5ccf4698 100644 --- a/modules/modelSetup/HiDreamFineTuneSetup.py +++ b/modules/modelSetup/HiDreamFineTuneSetup.py @@ -63,7 +63,7 @@ def create_parameters( "embeddings_4" ) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer) return parameter_group_collection @@ -78,7 +78,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_2", model.text_encoder_2, config.text_encoder_2, model.train_progress) self._setup_model_part_requires_grad("text_encoder_3", model.text_encoder_3, config.text_encoder_3, model.train_progress) self._setup_model_part_requires_grad("text_encoder_4", model.text_encoder_4, config.text_encoder_4, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress, + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress, freeze=ModuleFilter.create(config), debug=config.debug_mode) model.vae.requires_grad_(False) @@ -164,7 +164,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/HiDreamLoRASetup.py b/modules/modelSetup/HiDreamLoRASetup.py index 3cc92b5f2..5e4fded42 100644 --- a/modules/modelSetup/HiDreamLoRASetup.py +++ b/modules/modelSetup/HiDreamLoRASetup.py @@ -64,7 +64,7 @@ def create_parameters( "embeddings_4" ) - self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) return parameter_group_collection @@ -89,7 +89,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_2_lora", model.text_encoder_2_lora, config.text_encoder_2, model.train_progress) self._setup_model_part_requires_grad("text_encoder_3_lora", model.text_encoder_3_lora, config.text_encoder_3, model.train_progress) self._setup_model_part_requires_grad("text_encoder_4_lora", model.text_encoder_4_lora, config.text_encoder_4, model.train_progress) - 
self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -236,7 +236,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/HunyuanVideoFineTuneSetup.py b/modules/modelSetup/HunyuanVideoFineTuneSetup.py index 55fa89378..a1057f748 100644 --- a/modules/modelSetup/HunyuanVideoFineTuneSetup.py +++ b/modules/modelSetup/HunyuanVideoFineTuneSetup.py @@ -48,7 +48,7 @@ def create_parameters( "embeddings_2" ) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior, + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, freeze=ModuleFilter.create(config), debug=config.debug_mode) return parameter_group_collection @@ -62,7 +62,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_1", model.text_encoder_1, config.text_encoder, model.train_progress) self._setup_model_part_requires_grad("text_encoder_2", model.text_encoder_2, config.text_encoder_2, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) model.vae.requires_grad_(False) @@ -117,7 +117,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/HunyuanVideoLoRASetup.py b/modules/modelSetup/HunyuanVideoLoRASetup.py index 8bb1705ea..2652ba484 100644 --- a/modules/modelSetup/HunyuanVideoLoRASetup.py +++ b/modules/modelSetup/HunyuanVideoLoRASetup.py @@ -50,7 +50,7 @@ def create_parameters( "embeddings_2" ) - self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) return parameter_group_collection @@ -69,7 +69,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_1_lora", model.text_encoder_1_lora, config.text_encoder, model.train_progress) self._setup_model_part_requires_grad("text_encoder_2_lora", model.text_encoder_2_lora, config.text_encoder_2, model.train_progress) - self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -162,7 +162,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/PixArtAlphaFineTuneSetup.py b/modules/modelSetup/PixArtAlphaFineTuneSetup.py index 71559f03e..4ad4911d3 100644 --- a/modules/modelSetup/PixArtAlphaFineTuneSetup.py +++ b/modules/modelSetup/PixArtAlphaFineTuneSetup.py @@ -39,7 +39,7 @@ def create_parameters( "embeddings" ) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior, + 
self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, freeze=ModuleFilter.create(config), debug=config.debug_mode) return parameter_group_collection @@ -58,7 +58,7 @@ def __setup_requires_grad( model.train_progress, i) embedding.text_encoder_vector.requires_grad_(train_embedding) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) model.vae.requires_grad_(False) @@ -99,7 +99,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/PixArtAlphaLoRASetup.py b/modules/modelSetup/PixArtAlphaLoRASetup.py index 2e896e8f7..451a43b13 100644 --- a/modules/modelSetup/PixArtAlphaLoRASetup.py +++ b/modules/modelSetup/PixArtAlphaLoRASetup.py @@ -40,7 +40,7 @@ def create_parameters( "embeddings" ) - self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) return parameter_group_collection @@ -55,7 +55,7 @@ def __setup_requires_grad( model.vae.requires_grad_(False) self._setup_model_part_requires_grad("text_encoder_lora", model.text_encoder_lora, config.text_encoder, model.train_progress) - self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -119,7 +119,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/QwenFineTuneSetup.py b/modules/modelSetup/QwenFineTuneSetup.py index 3135b3139..2d305a1be 100644 --- a/modules/modelSetup/QwenFineTuneSetup.py +++ b/modules/modelSetup/QwenFineTuneSetup.py @@ -32,7 +32,7 @@ def create_parameters( parameter_group_collection = NamedParameterGroupCollection() self._create_model_part_parameters(parameter_group_collection, "text_encoder", model.text_encoder, config.text_encoder) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior, freeze=ModuleFilter.create(config), debug=config.debug_mode) + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, freeze=ModuleFilter.create(config), debug=config.debug_mode) if config.train_any_embedding() or config.train_any_output_embedding(): raise NotImplementedError("Embeddings not implemented for Qwen") @@ -45,7 +45,7 @@ def __setup_requires_grad( config: TrainConfig, ): self._setup_model_part_requires_grad("text_encoder", model.text_encoder, config.text_encoder, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) model.vae.requires_grad_(False) @@ -81,7 +81,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git 
a/modules/modelSetup/QwenLoRASetup.py b/modules/modelSetup/QwenLoRASetup.py index 10fd279b9..644cd8656 100644 --- a/modules/modelSetup/QwenLoRASetup.py +++ b/modules/modelSetup/QwenLoRASetup.py @@ -33,7 +33,7 @@ def create_parameters( parameter_group_collection = NamedParameterGroupCollection() self._create_model_part_parameters(parameter_group_collection, "text_encoder", model.text_encoder_lora, config.text_encoder) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer_lora, config.transformer) if config.train_any_embedding() or config.train_any_output_embedding(): raise NotImplementedError("Embeddings not implemented for Qwen") @@ -51,7 +51,7 @@ def __setup_requires_grad( model.vae.requires_grad_(False) self._setup_model_part_requires_grad("text_encoder", model.text_encoder_lora, config.text_encoder, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer_lora, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -110,7 +110,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/SanaFineTuneSetup.py b/modules/modelSetup/SanaFineTuneSetup.py index 683e77924..110d442e4 100644 --- a/modules/modelSetup/SanaFineTuneSetup.py +++ b/modules/modelSetup/SanaFineTuneSetup.py @@ -39,7 +39,7 @@ def create_parameters( "embeddings" ) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior, + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, freeze=ModuleFilter.create(config), debug=config.debug_mode) return parameter_group_collection @@ -52,7 +52,7 @@ def __setup_requires_grad( self._setup_embeddings_requires_grad(model, config) self._setup_model_part_requires_grad("text_encoder", model.text_encoder, config.text_encoder, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) model.vae.requires_grad_(False) @@ -93,7 +93,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/SanaLoRASetup.py b/modules/modelSetup/SanaLoRASetup.py index bebd51a0d..0fd7c04c0 100644 --- a/modules/modelSetup/SanaLoRASetup.py +++ b/modules/modelSetup/SanaLoRASetup.py @@ -40,7 +40,7 @@ def create_parameters( "embeddings" ) - self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) return parameter_group_collection @@ -55,7 +55,7 @@ def __setup_requires_grad( model.vae.requires_grad_(False) self._setup_model_part_requires_grad("text_encoder_lora", model.text_encoder_lora, config.text_encoder, model.train_progress) - self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.prior, model.train_progress) + 
self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -119,7 +119,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/StableDiffusion3FineTuneSetup.py b/modules/modelSetup/StableDiffusion3FineTuneSetup.py index 7857acafc..d575a1f25 100644 --- a/modules/modelSetup/StableDiffusion3FineTuneSetup.py +++ b/modules/modelSetup/StableDiffusion3FineTuneSetup.py @@ -54,7 +54,7 @@ def create_parameters( "embeddings_3" ) - self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.prior, + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, freeze=ModuleFilter.create(config), debug=config.debug_mode) return parameter_group_collection @@ -69,7 +69,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_1", model.text_encoder_1, config.text_encoder, model.train_progress) self._setup_model_part_requires_grad("text_encoder_2", model.text_encoder_2, config.text_encoder_2, model.train_progress) self._setup_model_part_requires_grad("text_encoder_3", model.text_encoder_3, config.text_encoder_3, model.train_progress) - self._setup_model_part_requires_grad("transformer", model.transformer, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) model.vae.requires_grad_(False) @@ -139,7 +139,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/modelSetup/StableDiffusion3LoRASetup.py b/modules/modelSetup/StableDiffusion3LoRASetup.py index 82da67efd..7adabac36 100644 --- a/modules/modelSetup/StableDiffusion3LoRASetup.py +++ b/modules/modelSetup/StableDiffusion3LoRASetup.py @@ -55,7 +55,7 @@ def create_parameters( "embeddings_3" ) - self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.prior) + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) return parameter_group_collection @@ -77,7 +77,7 @@ def __setup_requires_grad( self._setup_model_part_requires_grad("text_encoder_1_lora", model.text_encoder_1_lora, config.text_encoder, model.train_progress) self._setup_model_part_requires_grad("text_encoder_2_lora", model.text_encoder_2_lora, config.text_encoder_2, model.train_progress) self._setup_model_part_requires_grad("text_encoder_3_lora", model.text_encoder_3_lora, config.text_encoder_3, model.train_progress) - self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.prior, model.train_progress) + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) def setup_model( self, @@ -197,7 +197,7 @@ def setup_train_device( model.vae.eval() - if config.prior.train: + if config.transformer.train: model.transformer.train() else: model.transformer.eval() diff --git a/modules/trainer/CloudTrainer.py b/modules/trainer/CloudTrainer.py index 20426c01a..daa595892 100644 --- a/modules/trainer/CloudTrainer.py +++ b/modules/trainer/CloudTrainer.py @@ -161,6 +161,7 @@ def adjust(config, attribute: str, if_exists: 
bool=False): adjust(remote,"cache_dir") adjust(remote,"base_model_name", if_exists=True) adjust(remote.prior,"model_name", if_exists=True) + adjust(remote.transformer,"model_name", if_exists=True) adjust(remote,"output_model_destination") adjust(remote,"lora_model_name") diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index d27118bb3..d9ea2fa8f 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -87,7 +87,7 @@ def __setup_stable_diffusion_3_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, + has_transformer=True, has_text_encoder_1=True, has_text_encoder_2=True, has_text_encoder_3=True, @@ -105,8 +105,8 @@ def __setup_flux_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, - allow_override_prior=True, + has_transformer=True, + allow_override_transformer=True, has_text_encoder_1=True, has_text_encoder_2=True, has_vae=True, @@ -123,8 +123,8 @@ def __setup_chroma_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, - allow_override_prior=True, + has_transformer=True, + allow_override_transformer=True, has_text_encoder_1=True, has_vae=True, ) @@ -140,8 +140,8 @@ def __setup_qwen_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, - allow_override_prior=True, + has_transformer=True, + allow_override_transformer=True, has_text_encoder_1=True, has_vae=True, ) @@ -193,7 +193,7 @@ def __setup_pixart_alpha_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, + has_transformer=True, has_text_encoder=True, has_vae=True, ) @@ -209,7 +209,7 @@ def __setup_sana_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, + has_transformer=True, has_text_encoder=True, has_vae=True, ) @@ -225,7 +225,7 @@ def __setup_hunyuan_video_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, + has_transformer=True, has_text_encoder_1=True, has_text_encoder_2=True, has_vae=True, @@ -242,7 +242,7 @@ def __setup_hi_dream_ui(self): row = self.__create_base_dtype_components(row) row = self.__create_base_components( row, - has_prior=True, + has_transformer=True, has_text_encoder_1=True, has_text_encoder_2=True, has_text_encoder_3=True, @@ -306,6 +306,8 @@ def __create_base_components( has_unet: bool = False, has_prior: bool = False, allow_override_prior: bool = False, + has_transformer: bool = False, + allow_override_transformer: bool = False, allow_override_text_encoder_4: bool = False, has_text_encoder: bool = False, has_text_encoder_1: bool = False, @@ -326,8 +328,8 @@ def __create_base_components( if has_prior: if allow_override_prior: # prior model - components.label(self.scroll_frame, row, 0, "Prior Model", #TODO rename; as more models support this, the name "Prior" gets confusing - tooltip="Filename, directory or Hugging Face repository of the prior model. 
For Flux, Chroma or Qwen, it must be a safetensors file that only contains the transformer.") + components.label(self.scroll_frame, row, 0, "Prior Model", + tooltip="Filename, directory or Hugging Face repository of the prior model") components.file_entry( self.scroll_frame, row, 1, self.ui_state, "prior.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x @@ -341,6 +343,24 @@ def __create_base_components( row += 1 + if has_transformer: + if allow_override_transformer: + # transformer model + components.label(self.scroll_frame, row, 0, "Override Transformer", + tooltip="Can be used to override the transformer in the base model. Local safetensors files and files on HuggingFace are supported.") + components.file_entry( + self.scroll_frame, row, 1, self.ui_state, "transformer.model_name", + path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x + ) + + # transformer weight dtype + components.label(self.scroll_frame, row, 3, "Override Transformer Data Type", + tooltip="Overrides the transformer weight data type") + components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + self.ui_state, "transformer.weight_dtype") + + row += 1 + if has_text_encoder: # text encoder weight dtype components.label(self.scroll_frame, row, 3, "Override Text Encoder Data Type", @@ -379,7 +399,7 @@ def __create_base_components( if has_text_encoder_4: if allow_override_text_encoder_4: - # prior model + # text encoder 4 weight dtype components.label(self.scroll_frame, row, 0, "Text Encoder 4 Override", tooltip="Filename, directory or Hugging Face repository of the text encoder 4 model") components.file_entry( diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index 3e366e919..f2fa67eb1 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -192,7 +192,7 @@ def __setup_pixart_alpha_ui(self, column_0, column_1, column_2): self.__create_embedding_frame(column_0, 2) self.__create_base2_frame(column_1, 0) - self.__create_prior_frame(column_1, 1) + self.__create_transformer_frame(column_1, 1) self.__create_noise_frame(column_1, 2) self.__create_masked_frame(column_2, 1) @@ -245,7 +245,7 @@ def __setup_sana_ui(self, column_0, column_1, column_2): self.__create_embedding_frame(column_0, 2) self.__create_base2_frame(column_1, 0) - self.__create_prior_frame(column_1, 1) + self.__create_transformer_frame(column_1, 1) self.__create_noise_frame(column_1, 2) self.__create_masked_frame(column_2, 1) @@ -597,30 +597,30 @@ def __create_transformer_frame(self, master, row, supports_guidance_scale: bool # train transformer components.label(frame, 0, 0, "Train Transformer", tooltip="Enables training the Transformer model") - components.switch(frame, 0, 1, self.ui_state, "prior.train") + components.switch(frame, 0, 1, self.ui_state, "transformer.train") # train transformer epochs components.label(frame, 1, 0, "Stop Training After", tooltip="When to stop training the Transformer") - components.time_entry(frame, 1, 1, self.ui_state, "prior.stop_training_after", "prior.stop_training_after_unit", + components.time_entry(frame, 1, 1, self.ui_state, "transformer.stop_training_after", "transformer.stop_training_after_unit", supports_time_units=False) # transformer learning rate components.label(frame, 2, 0, "Transformer Learning Rate", tooltip="The learning rate of the Transformer. 
Overrides the base learning rate") - components.entry(frame, 2, 1, self.ui_state, "prior.learning_rate") + components.entry(frame, 2, 1, self.ui_state, "transformer.learning_rate") if supports_force_attention_mask: # transformer learning rate components.label(frame, 3, 0, "Force Attention Mask", tooltip="Force enables passing of a text embedding attention mask to the transformer. This can improve training on shorter captions.") - components.switch(frame, 3, 1, self.ui_state, "prior.attention_mask") + components.switch(frame, 3, 1, self.ui_state, "transformer.attention_mask") if supports_guidance_scale: # guidance scale components.label(frame, 4, 0, "Guidance Scale", tooltip="The guidance scale of guidance distilled models passed to the transformer during training.") - components.entry(frame, 4, 1, self.ui_state, "prior.guidance_scale") + components.entry(frame, 4, 1, self.ui_state, "transformer.guidance_scale") def __create_noise_frame(self, master, row, supports_generalized_offset_noise: bool = False, supports_dynamic_timestep_shifting: bool = False): frame = ctk.CTkFrame(master=master, corner_radius=5) diff --git a/modules/util/ModelNames.py b/modules/util/ModelNames.py index 706544402..8dec9a9bf 100644 --- a/modules/util/ModelNames.py +++ b/modules/util/ModelNames.py @@ -13,6 +13,7 @@ def __init__( self, base_model: str = "", prior_model: str = "", + transformer_model: str = "", effnet_encoder_model: str = "", decoder_model: str = "", text_encoder_4: str = "", @@ -27,6 +28,7 @@ def __init__( ): self.base_model = base_model self.prior_model = prior_model + self.transformer_model = transformer_model self.effnet_encoder_model = effnet_encoder_model self.decoder_model = decoder_model self.text_encoder_4 = text_encoder_4 diff --git a/modules/util/ModelWeightDtypes.py b/modules/util/ModelWeightDtypes.py index e3ac7959f..3893b3a5f 100644 --- a/modules/util/ModelWeightDtypes.py +++ b/modules/util/ModelWeightDtypes.py @@ -10,6 +10,7 @@ def __init__( fallback_train_dtype: DataType, unet: DataType, prior: DataType, + transformer: DataType, text_encoder: DataType, text_encoder_2: DataType, text_encoder_3: DataType, @@ -27,6 +28,7 @@ def __init__( self.unet = unet self.prior = prior + self.transformer = transformer self.text_encoder = text_encoder self.text_encoder_2 = text_encoder_2 self.text_encoder_3 = text_encoder_3 @@ -43,6 +45,7 @@ def all_dtypes(self) -> list: return [ self.unet, self.prior, + self.transformer, self.text_encoder, self.text_encoder_2, self.text_encoder_3, diff --git a/modules/util/config/SampleConfig.py b/modules/util/config/SampleConfig.py index 362cd1c54..41e5fabe3 100644 --- a/modules/util/config/SampleConfig.py +++ b/modules/util/config/SampleConfig.py @@ -22,7 +22,7 @@ class SampleConfig(BaseConfig): text_encoder_2_layer_skip: int text_encoder_3_layer_skip: int text_encoder_4_layer_skip: int - prior_attention_mask: bool + transformer_attention_mask: bool force_last_timestep: bool sample_inpainting: bool @@ -37,7 +37,7 @@ def from_train_config(self, train_config): self.text_encoder_2_layer_skip = train_config.text_encoder_2_layer_skip self.text_encoder_3_layer_skip = train_config.text_encoder_3_layer_skip self.text_encoder_4_layer_skip = train_config.text_encoder_4_layer_skip - self.prior_attention_mask = train_config.prior.attention_mask + self.transformer_attention_mask = train_config.transformer.attention_mask self.force_last_timestep = train_config.rescale_noise_scheduler_to_zero_terminal_snr @staticmethod @@ -61,7 +61,7 @@ def default_values(): 
data.append(("text_encoder_2_layer_skip", 0, int, False)) data.append(("text_encoder_3_layer_skip", 0, int, False)) data.append(("text_encoder_4_layer_skip", 0, int, False)) - data.append(("prior_attention_mask", False, bool, False)) + data.append(("transformer_attention_mask", False, bool, False)) data.append(("force_last_timestep", False, bool, False)) data.append(("sample_inpainting", False, bool, False)) diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index ae210a0c4..a84fa4e44 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -404,6 +404,9 @@ class TrainConfig(BaseConfig): # prior prior: TrainModelPartConfig + # transformer + transformer: TrainModelPartConfig + # text encoder text_encoder: TrainModelPartConfig text_encoder_layer_skip: int @@ -499,7 +502,7 @@ class TrainConfig(BaseConfig): def __init__(self, data: list[(str, Any, type, bool)]): super().__init__( data, - config_version=8, + config_version=9, config_migrations={ 0: self.__migration_0, 1: self.__migration_1, @@ -509,6 +512,7 @@ def __init__(self, data: list[(str, Any, type, bool)]): 5: self.__migration_5, 6: self.__migration_6, 7: self.__migration_7, + 8: self.__migration_8, } ) @@ -698,12 +702,21 @@ def __migration_7(self, data: dict) -> dict: return migrated_data + def __migration_8(self, data: dict) -> dict: + migrated_data = data.copy() + + if migrated_data["model_type"] != "STABLE_CASCADE_1" and migrated_data["model_type"] != "WUERSTCHEN_2": + migrated_data["transformer"] = migrated_data["prior"] + + return migrated_data + def weight_dtypes(self) -> ModelWeightDtypes: return ModelWeightDtypes( self.train_dtype, self.fallback_train_dtype, self.weight_dtype if self.unet.weight_dtype == DataType.NONE else self.unet.weight_dtype, self.weight_dtype if self.prior.weight_dtype == DataType.NONE else self.prior.weight_dtype, + self.weight_dtype if self.transformer.weight_dtype == DataType.NONE else self.transformer.weight_dtype, self.weight_dtype if self.text_encoder.weight_dtype == DataType.NONE else self.text_encoder.weight_dtype, self.weight_dtype if self.text_encoder_2.weight_dtype == DataType.NONE else self.text_encoder_2.weight_dtype, self.weight_dtype if self.text_encoder_3.weight_dtype == DataType.NONE else self.text_encoder_3.weight_dtype, @@ -721,6 +734,7 @@ def model_names(self) -> ModelNames: return ModelNames( base_model=self.base_model_name, prior_model=self.prior.model_name, + transformer_model=self.transformer.model_name, effnet_encoder_model=self.effnet_encoder.model_name, decoder_model=self.decoder.model_name, text_encoder_4=self.text_encoder_4.model_name, @@ -946,6 +960,15 @@ def default_values() -> 'TrainConfig': prior.weight_dtype = DataType.NONE data.append(("prior", prior, TrainModelPartConfig, False)) + # prior + transformer = TrainModelPartConfig.default_values() + transformer.model_name = "" + transformer.train = True + transformer.stop_training_after = 0 + transformer.learning_rate = None + transformer.weight_dtype = DataType.NONE + data.append(("transformer", transformer, TrainModelPartConfig, False)) + # text encoder text_encoder = TrainModelPartConfig.default_values() text_encoder.train = True diff --git a/training_presets/#chroma Finetune 16GB.json b/training_presets/#chroma Finetune 16GB.json index b0bab8cfa..a25a5b564 100644 --- a/training_presets/#chroma Finetune 16GB.json +++ b/training_presets/#chroma Finetune 16GB.json @@ -7,7 +7,7 @@ "gradient_checkpointing": "CPU_OFFLOADED", "layer_offload_fraction": 0.4, 
"dataloader_threads": 1, - "prior": { + "transformer": { "train": true, "weight_dtype": "BFLOAT_16" }, diff --git a/training_presets/#chroma Finetune 24GB.json b/training_presets/#chroma Finetune 24GB.json index 65a44b09c..a57648fbf 100644 --- a/training_presets/#chroma Finetune 24GB.json +++ b/training_presets/#chroma Finetune 24GB.json @@ -4,7 +4,7 @@ "learning_rate": 1e-5, "model_type": "CHROMA_1", "resolution": "512", - "prior": { + "transformer": { "train": true, "weight_dtype": "BFLOAT_16" }, diff --git a/training_presets/#chroma Finetune 8GB.json b/training_presets/#chroma Finetune 8GB.json index 70db75cf9..4768ac714 100644 --- a/training_presets/#chroma Finetune 8GB.json +++ b/training_presets/#chroma Finetune 8GB.json @@ -7,7 +7,7 @@ "gradient_checkpointing": "CPU_OFFLOADED", "layer_offload_fraction": 0.85, "dataloader_threads": 1, - "prior": { + "transformer": { "train": true, "weight_dtype": "BFLOAT_16" }, diff --git a/training_presets/#chroma LoRA 16GB.json b/training_presets/#chroma LoRA 16GB.json index d718eb75e..a99eecca4 100644 --- a/training_presets/#chroma LoRA 16GB.json +++ b/training_presets/#chroma LoRA 16GB.json @@ -4,7 +4,7 @@ "learning_rate": 0.0003, "model_type": "CHROMA_1", "resolution": "512", - "prior": { + "transformer": { "train": true, "weight_dtype": "FLOAT_8" }, diff --git a/training_presets/#chroma LoRA 24GB.json b/training_presets/#chroma LoRA 24GB.json index bfdaccd4c..5877009c6 100644 --- a/training_presets/#chroma LoRA 24GB.json +++ b/training_presets/#chroma LoRA 24GB.json @@ -4,7 +4,7 @@ "learning_rate": 0.0003, "model_type": "CHROMA_1", "resolution": "512", - "prior": { + "transformer": { "train": true, "weight_dtype": "BFLOAT_16" }, diff --git a/training_presets/#chroma LoRA 8GB.json b/training_presets/#chroma LoRA 8GB.json index ded17978f..6fa19670c 100644 --- a/training_presets/#chroma LoRA 8GB.json +++ b/training_presets/#chroma LoRA 8GB.json @@ -7,7 +7,7 @@ "gradient_checkpointing": "CPU_OFFLOADED", "layer_offload_fraction": 0.6, "dataloader_threads": 1, - "prior": { + "transformer": { "train": true, "weight_dtype": "FLOAT_8" }, diff --git a/training_presets/#flux LoRA.json b/training_presets/#flux LoRA.json index 517d1113e..c4456024c 100644 --- a/training_presets/#flux LoRA.json +++ b/training_presets/#flux LoRA.json @@ -7,7 +7,7 @@ "output_model_destination": "models/lora.safetensors", "output_model_format": "SAFETENSORS", "resolution": "768", - "prior": { + "transformer": { "train": true, "weight_dtype": "NFLOAT_4" }, diff --git a/training_presets/#qwen Finetune 16GB.json b/training_presets/#qwen Finetune 16GB.json index ea29ea490..811d7e0b1 100644 --- a/training_presets/#qwen Finetune 16GB.json +++ b/training_presets/#qwen Finetune 16GB.json @@ -7,7 +7,7 @@ "gradient_checkpointing": "CPU_OFFLOADED", "layer_offload_fraction": 0.75, "dataloader_threads": 1, - "prior": { + "transformer": { "train": true, "weight_dtype": "BFLOAT_16" }, diff --git a/training_presets/#qwen Finetune 24GB.json b/training_presets/#qwen Finetune 24GB.json index a425e07f8..8bee3cd3f 100644 --- a/training_presets/#qwen Finetune 24GB.json +++ b/training_presets/#qwen Finetune 24GB.json @@ -7,7 +7,7 @@ "gradient_checkpointing": "CPU_OFFLOADED", "layer_offload_fraction": 0.55, "dataloader_threads": 1, - "prior": { + "transformer": { "train": true, "weight_dtype": "BFLOAT_16" }, diff --git a/training_presets/#qwen LoRA 16GB.json b/training_presets/#qwen LoRA 16GB.json index eae8bad72..a101c788c 100644 --- a/training_presets/#qwen LoRA 16GB.json +++ b/training_presets/#qwen 
LoRA 16GB.json @@ -7,7 +7,7 @@ "gradient_checkpointing": "CPU_OFFLOADED", "layer_offload_fraction": 0.5, "dataloader_threads": 1, - "prior": { + "transformer": { "train": true, "weight_dtype": "FLOAT_8" }, diff --git a/training_presets/#qwen LoRA 24GB.json b/training_presets/#qwen LoRA 24GB.json index 6b76111e2..bd03b0cc2 100644 --- a/training_presets/#qwen LoRA 24GB.json +++ b/training_presets/#qwen LoRA 24GB.json @@ -7,7 +7,7 @@ "gradient_checkpointing": "CPU_OFFLOADED", "layer_offload_fraction": 0.1, "dataloader_threads": 1, - "prior": { + "transformer": { "train": true, "weight_dtype": "FLOAT_8" }, diff --git a/training_presets/#sd 3.json b/training_presets/#sd 3.json index 597ff0704..a741fc574 100644 --- a/training_presets/#sd 3.json +++ b/training_presets/#sd 3.json @@ -4,7 +4,7 @@ "output_dtype": "FLOAT_16", "output_model_destination": "models/model.safetensors", "output_model_format": "SAFETENSORS", - "prior": { + "transformer": { "weight_dtype": "FLOAT_32" }, "resolution": "1024", From 71af1f0ceec30ade9b2547188baa4872d75ad429 Mon Sep 17 00:00:00 2001 From: dxqb Date: Thu, 16 Oct 2025 16:32:35 +0200 Subject: [PATCH 21/54] ui fix --- modules/ui/ModelTab.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 76965cdd1..e478e6ad6 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -290,6 +290,8 @@ def __create_base_dtype_components(self, row: int) -> int: wide_tooltip=True) components.entry(self.scroll_frame, row, 1, self.ui_state, "secrets.huggingface_token") + row += 1 + # base model components.label(self.scroll_frame, row, 0, "Base Model", tooltip="Filename, directory or Hugging Face repository of the base model") From 0f58a5e8f1fa74bf9f5332e339be4e0cfc990cf5 Mon Sep 17 00:00:00 2001 From: dxqb Date: Thu, 16 Oct 2025 18:19:38 +0200 Subject: [PATCH 22/54] GGUF with DoRA --- modules/module/LoRAModule.py | 1 - modules/util/quantization_util.py | 23 +++++++++++------------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/modules/module/LoRAModule.py b/modules/module/LoRAModule.py index c9c24777b..c9d91cf14 100644 --- a/modules/module/LoRAModule.py +++ b/modules/module/LoRAModule.py @@ -385,7 +385,6 @@ def check_initialized(self): def forward(self, x, *args, **kwargs): self.check_initialized() - A = self.lora_down.weight B = self.lora_up.weight orig_weight = get_unquantized_weight(self.orig_module, A.dtype, self.train_device) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 5ba318e20..c860a4a6e 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -8,6 +8,8 @@ import torch from torch import Tensor, nn +from diffusers.quantizers.gguf.utils import GGUFLinear, dequantize_gguf_tensor + try: from modules.module.quantized.LinearNf4 import LinearNf4 @@ -196,22 +198,19 @@ def quantize_layers(module: nn.Module, device: torch.device, train_dtype: DataTy child_module.quantize(device) -def get_unquantized_weight(module: nn.Module, dtype: torch.dtype, device: torch.device) -> Tensor: +def get_unquantized_weight(module: nn.Linear, dtype: torch.dtype, device: torch.device) -> Tensor: + assert isinstance(module, nn.Linear) if isinstance(module, QuantizedLinearMixin): return module.unquantized_weight(dtype, device) + elif isinstance(module, GGUFLinear): + return dequantize_gguf_tensor(module.weight).to(dtype=dtype) + else: + return module.weight.detach().to(dtype=dtype) - return module.weight.detach().to(dtype=dtype) - - -def 
get_weight_shape(module: nn.Module) -> torch.Size: - param = module.weight - - if bnb is not None: - if isinstance(module, LinearNf4): - return module.shape - - return param.shape +def get_weight_shape(module: nn.Linear) -> torch.Size: + assert isinstance(module, nn.Linear) + return torch.Size((module.out_features, module.in_features)) def get_offload_tensors(module: nn.Module) -> list[torch.Tensor]: tensors = [] From 881e7e56998b77021959751ed714c8d6d8248608 Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 17 Oct 2025 11:01:17 +0200 Subject: [PATCH 23/54] GGUF A8 float bugfix --- modules/module/quantized/LinearGGUFA8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index 067cecabf..d75003481 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -19,7 +19,7 @@ def int8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> res_scaled.add_(bias.to(x.dtype)) return res_scaled -def fp8_forward_both_axiswise(x: Tensor, weight: float | Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: +def fp8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) w_8, w_scale = quantize_fp8_axiswise(weight, dim=-1) one = torch.ones(1, device=x.device) From 25ccc0c661d248a6bf23dd0ba0dfe2e117f138ee Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 17 Oct 2025 12:02:07 +0200 Subject: [PATCH 24/54] improve check for #1050 --- modules/util/quantization_util.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 5ba318e20..7ffeea2ae 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -69,7 +69,7 @@ def __create_fp8_linear_layer(module: nn.Linear, copy_parameters: bool) -> nn.Mo return quant_linear -def __replace_linear_layers( +def __replace_linear_layers_recursive( parent_module: nn.Module, convert_fn: Callable[[nn.Linear, bool], nn.Module], keep_in_fp32_modules: list[str] | None = None, @@ -85,7 +85,6 @@ def __replace_linear_layers( visited_modules = set() visited_modules.add(id(parent_module)) - if isinstance(parent_module, (nn.ModuleList, nn.Sequential)): for i, module in enumerate(parent_module): if isinstance(module, nn.Linear): @@ -93,7 +92,7 @@ def __replace_linear_layers( parent_module[i] = quant_linear del module elif id(module) not in visited_modules: - __replace_linear_layers( + __replace_linear_layers_recursive( parent_module=module, convert_fn=convert_fn, keep_in_fp32_modules=keep_in_fp32_modules, @@ -112,7 +111,7 @@ def __replace_linear_layers( setattr(parent_module, attr_name, quant_linear) del module elif isinstance(module, nn.Module) and id(module) not in visited_modules: - __replace_linear_layers( + __replace_linear_layers_recursive( parent_module=module, convert_fn=convert_fn, keep_in_fp32_modules=keep_in_fp32_modules, @@ -121,10 +120,21 @@ def __replace_linear_layers( visited_modules=visited_modules, ) +def __replace_linear_layers( + parent_module: nn.Module, + convert_fn: Callable[[nn.Linear, bool], nn.Module], + keep_in_fp32_modules: list[str] | None = None, + copy_parameters: bool = False, +): + __replace_linear_layers_recursive(parent_module, convert_fn, keep_in_fp32_modules, copy_parameters) + + #ensure that all Linear layers were replaced + #https://github.com/Nerogar/OneTrainer/issues/1050 for 
name, module in parent_module.named_modules(): - #ensure that all Linear layers were replaced - #https://github.com/Nerogar/OneTrainer/issues/1050 - assert not isinstance(module, nn.Linear) or isinstance(module, QuantizedLinearMixin), f"Linear layer {name} was not found in model for quantization" + assert (not isinstance(module, nn.Linear) + or isinstance(module, QuantizedLinearMixin) + or any(s in name.split('.') for s in keep_in_fp32_modules) + ), f"Linear layer {name} was not found in model for quantization" def replace_linear_with_nf4_layers( parent_module: nn.Module, From b4d8f30116ddc2cd45e0a38bf3f2404b80c750c5 Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 17 Oct 2025 12:08:32 +0200 Subject: [PATCH 25/54] improve check for #1050 --- modules/util/quantization_util.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 6a92f0756..963fb47f0 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -40,7 +40,7 @@ def __create_linear_layer(construct_fn, module: nn.Linear, copy_parameters: bool return quant_linear -def __replace_linear_layers_recursive( +def __replace_linear_layers( parent_module: nn.Module, construct_fn, keep_in_fp32_modules: list[str] | None = None, @@ -63,7 +63,7 @@ def __replace_linear_layers_recursive( parent_module[i] = quant_linear del module elif id(module) not in visited_modules: - __replace_linear_layers_recursive( + __replace_linear_layers( parent_module=module, construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, @@ -82,7 +82,7 @@ def __replace_linear_layers_recursive( setattr(parent_module, attr_name, quant_linear) del module elif isinstance(module, nn.Module) and id(module) not in visited_modules: - __replace_linear_layers_recursive( + __replace_linear_layers( parent_module=module, construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, @@ -91,22 +91,6 @@ def __replace_linear_layers_recursive( visited_modules=visited_modules, ) -def __replace_linear_layers( - parent_module: nn.Module, - convert_fn: Callable[[nn.Linear, bool], nn.Module], - keep_in_fp32_modules: list[str] | None = None, - copy_parameters: bool = False, -): - __replace_linear_layers_recursive(parent_module, convert_fn, keep_in_fp32_modules, copy_parameters) - - #ensure that all Linear layers were replaced - #https://github.com/Nerogar/OneTrainer/issues/1050 - for name, module in parent_module.named_modules(): - assert (not isinstance(module, nn.Linear) - or isinstance(module, QuantizedLinearMixin) - or any(s in name.split('.') for s in keep_in_fp32_modules) - ), f"Linear layer {name} was not found in model for quantization" - def replace_linear_with_quantized_layers( parent_module: nn.Module, dtype: DataType, @@ -133,6 +117,14 @@ def replace_linear_with_quantized_layers( copy_parameters=copy_parameters, ) + #ensure that all Linear layers were replaced + #https://github.com/Nerogar/OneTrainer/issues/1050 + for name, module in parent_module.named_modules(): + assert (not isinstance(module, nn.Linear) + or isinstance(module, QuantizedLinearMixin) + or any(s in name.split('.') for s in keep_in_fp32_modules) + ), f"Linear layer {name} was not found in model for quantization" + def is_quantized_parameter( module: nn.Module, From da296d61f3e0ad1ecd25d7090d2bb2535a9f9d39 Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 17 Oct 2025 12:16:42 +0200 Subject: [PATCH 26/54] re-enabled int W8A8 --- modules/ui/ModelTab.py | 2 +- 
1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index e478e6ad6..4d38c978b 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -264,7 +264,7 @@ def __create_dtype_options(self, include_none: bool=True, include_svd: bool=Fals ("float16", DataType.FLOAT_16), ("float8 (W8)", DataType.FLOAT_8), ("float W8A8", DataType.FLOAT_W8A8), - #("int W8A8", DataType.INT_W8A8), #not recommended + ("int W8A8", DataType.INT_W8A8), # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 ("nfloat4", DataType.NFLOAT_4), ] From 56902b6c750c6f9cdf08b0207b2b277f64a688db Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 28 Oct 2025 19:50:22 +0100 Subject: [PATCH 27/54] only quantize activations if GGUF weights are actually quantized --- modules/module/quantized/LinearGGUFA8.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index d75003481..1898c2915 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -9,6 +9,10 @@ from diffusers.quantizers.gguf.utils import GGUFLinear, dequantize_gguf_tensor +import gguf + +UNQUANTIZED_TYPES = [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16] + def int8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_int8_axiswise(x, dim=-1) @@ -69,28 +73,28 @@ def backward(ctx, x: Tensor): weight, = ctx.saved_tensors return fp8_backward_both_axiswise(x, weight), None, None - class LinearGGUFA8(GGUFLinear): - def __init__(self, dtype, compute_dtype, *args, **kwargs): + def __init__(self, dtype: torch.dtype, *args, **kwargs): super().__init__(*args, **kwargs) assert dtype in [torch.int8, torch.float8_e4m3fn] self._dtype = dtype - self._compute_dtype = compute_dtype - def forward(self, x_orig: torch.Tensor) -> torch.Tensor: assert not self.weight.requires_grad - x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) + x = x_orig.to(self.compute_dtype).reshape(-1, x_orig.shape[-1]) w = dequantize_gguf_tensor(self.weight) - if x.shape[0] > 16: + if x.shape[0] > 16 and self.weight.quant_type not in UNQUANTIZED_TYPES: if self._dtype == torch.int8: y = LinearGGUFIntA8RequantFunction.apply(x, w, self.bias) else: y = LinearGGUFFpA8RequantFunction.apply(x, w, self.bias) else: - y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) + x = x.to(self.compute_dtype) + w = w.to(self.compute_dtype) + bias = self.bias.to(self.compute_dtype) if self.bias is not None else None + y = torch.nn.functional.linear(x, w, bias) - assert y.dtype == self._compute_dtype + assert y.dtype == self.compute_dtype return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) From eaf4fe2c1cd3560dce489657b8d7ffb58f6b8147 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 2 Nov 2025 11:25:10 +0100 Subject: [PATCH 28/54] make layer filter a component --- modules/ui/TrainingTab.py | 149 ++++++++-------------------------- modules/util/ui/components.py | 76 +++++++++++++++++ 2 files changed, 109 insertions(+), 116 deletions(-) diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index ab8b44520..b5d7a2079 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -42,18 +42,6 @@ def __init__(self, master, train_config: TrainConfig, 
ui_state: UIState): master.grid_rowconfigure(0, weight=1) master.grid_columnconfigure(0, weight=1) - #layer filter: - self.layer_entry = None - self.layer_entry_fg_color = None - self.layer_entry_text_color = None - self.layer_selector = None - self.regex_label = None - self.regex_switch = None - self.presets = {} - self.presets_list = [] - self.prior_custom = "" - self.prior_selected = None - self.scroll_frame = None self.refresh_ui() @@ -81,32 +69,6 @@ def refresh_ui(self): column_2.grid(row=0, column=2, sticky="nsew") column_2.grid_columnconfigure(0, weight=1) - if self.train_config.model_type.is_stable_diffusion(): #TODO simplify - self.presets = sd_presets - elif self.train_config.model_type.is_stable_diffusion_xl(): - self.presets = sdxl_presets - elif self.train_config.model_type.is_stable_diffusion_3(): - self.presets = sd3_presets - elif self.train_config.model_type.is_wuerstchen(): - self.presets = sc_presets - elif self.train_config.model_type.is_pixart(): - self.presets = pixart_presets - elif self.train_config.model_type.is_flux(): - self.presets = flux_presets - elif self.train_config.model_type.is_qwen(): - self.presets = qwen_presets - elif self.train_config.model_type.is_chroma(): - self.presets = chroma_presets - elif self.train_config.model_type.is_sana(): - self.presets = sana_presets - elif self.train_config.model_type.is_hunyuan_video(): - self.presets = hunyuan_video_presets - elif self.train_config.model_type.is_hi_dream(): - self.presets = hidream_presets - else: - self.presets = {"full": []} - self.presets_list = list(self.presets.keys()) + ["custom"] - if self.train_config.model_type.is_stable_diffusion(): self.__setup_stable_diffusion_ui(column_0, column_1, column_2) if self.train_config.model_type.is_stable_diffusion_3(): @@ -779,85 +741,40 @@ def __create_loss_frame(self, master, row, supports_vb_loss: bool = False): components.options(frame, 8, 1, [str(x) for x in list(LossScaler)], self.ui_state, "loss_scaler") def __create_layer_frame(self, master, row): - frame = ctk.CTkFrame(master=master, corner_radius=5) - frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") - frame.grid_columnconfigure(0, weight=1) - - components.label(frame, 0, 0, "Layer Filter", - tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own. A blank custom field will train all layers.") - self.layer_selector = components.options( - frame, 0, 1, self.presets_list, self.ui_state, "layer_filter_preset", - command=self.__preset_set_layer_choice - ) - - self.layer_entry = components.entry( - frame, 1, 0, self.ui_state, "layer_filter", - tooltip="Comma-separated list of diffusion layers to train. Regular expressions (if toggled) are supported. Any model layer with a matching name will be trained" - ) - self.layer_entry_fg_color = self.layer_entry.cget("fg_color") - self.layer_entry_text_color = self.layer_entry.cget("text_color") - - self.regex_label = components.label( - frame, 2, 0, "Use Regex", - tooltip="If enabled, layer filter patterns are interpreted as regular expressions. Otherwise, simple substring matching is used." 
- ) - self.regex_switch = components.switch( - frame, 2, 1, self.ui_state, "layer_filter_regex" - ) - - # Let the user set their own layer filter - if self.train_config.layer_filter and self.train_config.layer_filter_preset == "custom": - self.prior_custom = self.train_config.layer_filter - else: - self.prior_custom = "" - - self.layer_entry.grid_configure(columnspan=2, sticky="ew") - # Some configs will come with the layer_filter_preset unset or wrong for - # the new model, so let's set it now to a reasonable default so it hits - # the UI correctly. - if self.layer_selector.get() not in self.presets_list: - self.layer_selector.set(self.presets_list[0]) - self.__preset_set_layer_choice(self.layer_selector.get()) - - - def __preset_set_layer_choice(self, selected: str): - if not selected: - selected = self.presets_list[0] - - if selected == "custom": - # Restore prior custom text and allow editing + regex toggle - self.layer_entry.configure(state="normal", fg_color=self.layer_entry_fg_color, text_color=self.layer_entry_text_color) - self.layer_entry.cget('textvariable').set(self.prior_custom) - self.regex_label.grid() - self.regex_switch.grid() + presets = [] + if self.train_config.model_type.is_stable_diffusion(): #TODO simplify + presets = sd_presets + elif self.train_config.model_type.is_stable_diffusion_xl(): + presets = sdxl_presets + elif self.train_config.model_type.is_stable_diffusion_3(): + presets = sd3_presets + elif self.train_config.model_type.is_wuerstchen(): + presets = sc_presets + elif self.train_config.model_type.is_pixart(): + presets = pixart_presets + elif self.train_config.model_type.is_flux(): + presets = flux_presets + elif self.train_config.model_type.is_qwen(): + presets = qwen_presets + elif self.train_config.model_type.is_chroma(): + presets = chroma_presets + elif self.train_config.model_type.is_sana(): + presets = sana_presets + elif self.train_config.model_type.is_hunyuan_video(): + presets = hunyuan_video_presets + elif self.train_config.model_type.is_hi_dream(): + presets = hidream_presets else: - # Preserve custom text before overwriting - if self.prior_selected == "custom": - self.prior_custom = self.layer_entry.get() - - # Resolve preset definition (list[str] OR {'patterns': [...], 'regex': bool}) - preset_def = self.presets.get(selected, []) - if isinstance(preset_def, dict): - patterns = preset_def.get("patterns", []) - preset_uses_regex = bool(preset_def.get("regex", False)) - else: - patterns = preset_def - preset_uses_regex = False - - disabled_color = ("gray85", "gray17") - disabled_text_color = ("gray30", "gray70") - self.layer_entry.configure(state="disabled", fg_color=disabled_color, text_color=disabled_text_color) - self.layer_entry.cget('textvariable').set(",".join(patterns)) - - self.train_config.layer_filter = ",".join(patterns) - - self.train_config.layer_filter_regex_regex = preset_uses_regex - self.ui_state.get_var("layer_filter_regex").set(preset_uses_regex) - - self.regex_label.grid_remove() - self.regex_switch.grid_remove() - - self.prior_selected = selected + presets = {"full": []} + components.layer_filter_entry(master, row, 0, self.ui_state, + preset_var_name="layer_filter_preset", presets=presets, + preset_label="Layer Filter", + preset_tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own. A blank custom field will train all layers.", + entry_var_name="layer_filter", + entry_tooltip="Comma-separated list of diffusion layers to train. Regular expressions (if toggled) are supported. 
Any model layer with a matching name will be trained", + regex_var_name="layer_filter_regex", + regex_tooltip="If enabled, layer filter patterns are interpreted as regular expressions. Otherwise, simple substring matching is used.", + ) def __open_optimizer_params_window(self): window = OptimizerParamsWindow(self.master, self.train_config, self.ui_state) diff --git a/modules/util/ui/components.py b/modules/util/ui/components.py index 3043f32fc..1c541a606 100644 --- a/modules/util/ui/components.py +++ b/modules/util/ui/components.py @@ -296,6 +296,82 @@ def time_entry(master, row, column, ui_state: UIState, var_name: str, unit_var_n return frame +def layer_filter_entry(master, row, column, ui_state: UIState, preset_var_name: str, preset_label: str, preset_tooltip: str, presets, entry_var_name, entry_tooltip: str, regex_var_name, regex_tooltip: str): + frame = ctk.CTkFrame(master=master, corner_radius=5) + frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") + frame.grid_columnconfigure(0, weight=1) + + layer_entry = entry( + frame, 1, 0, ui_state, entry_var_name, + tooltip=entry_tooltip + ) + layer_entry_fg_color = layer_entry.cget("fg_color") + layer_entry_text_color = layer_entry.cget("text_color") + + regex_label = label( + frame, 2, 0, "Use Regex", + tooltip=regex_tooltip, + ) + regex_switch = switch( + frame, 2, 1, ui_state, regex_var_name + ) + + # Let the user set their own layer filter + # TODO + #if self.train_config.layer_filter and self.train_config.layer_filter_preset == "custom": + # self.prior_custom = self.train_config.layer_filter + #else: + # self.prior_custom = "" + + layer_entry.grid_configure(columnspan=2, sticky="ew") + + presets_list = list(presets.keys()) + ["custom"] + + + def preset_set_layer_choice(selected: str): + if not selected or selected not in presets_list: + selected = presets_list[0] + + if selected == "custom": + # Restore prior custom text and allow editing + regex toggle + layer_entry.configure(state="normal", fg_color=layer_entry_fg_color, text_color=layer_entry_text_color) + #layer_entry.cget('textvariable').set("") + regex_label.grid() + regex_switch.grid() + else: + # Preserve custom text before overwriting + #if self.prior_selected == "custom": + # self.prior_custom = self.layer_entry.get() + + # Resolve preset definition (list[str] OR {'patterns': [...], 'regex': bool}) + preset_def = presets.get(selected, []) + if isinstance(preset_def, dict): + patterns = preset_def.get("patterns", []) + preset_uses_regex = bool(preset_def.get("regex", False)) + else: + patterns = preset_def + preset_uses_regex = False + + disabled_color = ("gray85", "gray17") + disabled_text_color = ("gray30", "gray70") + layer_entry.configure(state="disabled", fg_color=disabled_color, text_color=disabled_text_color) + layer_entry.cget('textvariable').set(",".join(patterns)) + + ui_state.get_var("layer_filter").set(",".join(patterns)) + ui_state.get_var("layer_filter_regex").set(preset_uses_regex) + + regex_label.grid_remove() + regex_switch.grid_remove() + +# self.prior_selected = selected + + label(frame, 0, 0, preset_label, + tooltip=preset_tooltip) + layer_selector = options( + frame, 0, 1, presets_list, ui_state, preset_var_name, + command=preset_set_layer_choice + ) + preset_set_layer_choice(layer_selector.get()) def icon_button(master, row, column, text, command): component = ctk.CTkButton(master, text=text, width=40, command=command) From 867b84cd110855e58b8a579d9b1ea1fc6f85378c Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 2 Nov 2025 14:47:55 +0100 Subject: 
[PATCH 29/54] quantization layer filter --- .../modelLoader/ChromaEmbeddingModelLoader.py | 4 +- .../modelLoader/ChromaFineTuneModelLoader.py | 4 +- modules/modelLoader/ChromaLoRAModelLoader.py | 4 +- .../modelLoader/FluxEmbeddingModelLoader.py | 4 +- .../modelLoader/FluxFineTuneModelLoader.py | 4 +- modules/modelLoader/FluxLoRAModelLoader.py | 4 +- .../HiDreamEmbeddingModelLoader.py | 4 +- .../modelLoader/HiDreamFineTuneModelLoader.py | 4 +- modules/modelLoader/HiDreamLoRAModelLoader.py | 4 +- .../HunyuanVideoEmbeddingModelLoader.py | 4 +- .../HunyuanVideoFineTuneModelLoader.py | 4 +- .../HunyuanVideoLoRAModelLoader.py | 4 +- .../PixArtAlphaEmbeddingModelLoader.py | 4 +- .../PixArtAlphaFineTuneModelLoader.py | 4 +- .../modelLoader/PixArtAlphaLoRAModelLoader.py | 4 +- .../modelLoader/QwenFineTuneModelLoader.py | 4 +- modules/modelLoader/QwenLoRAModelLoader.py | 4 +- .../modelLoader/SanaEmbeddingModelLoader.py | 4 +- .../modelLoader/SanaFineTuneModelLoader.py | 4 +- modules/modelLoader/SanaLoRAModelLoader.py | 4 +- .../StableDiffusion3EmbeddingModelLoader.py | 4 +- .../StableDiffusion3FineTuneModelLoader.py | 4 +- .../StableDiffusion3LoRAModelLoader.py | 4 +- .../StableDiffusionEmbeddingModelLoader.py | 4 +- .../StableDiffusionFineTuneModelLoader.py | 4 +- .../StableDiffusionLoRAModelLoader.py | 4 +- .../StableDiffusionXLEmbeddingModelLoader.py | 4 +- .../StableDiffusionXLFineTuneModelLoader.py | 4 +- .../StableDiffusionXLLoRAModelLoader.py | 4 +- .../WuerstchenEmbeddingModelLoader.py | 4 +- .../WuerstchenFineTuneModelLoader.py | 4 +- .../modelLoader/WuerstchenLoRAModelLoader.py | 4 +- .../modelLoader/chroma/ChromaModelLoader.py | 16 +- modules/modelLoader/flux/FluxModelLoader.py | 18 +- .../modelLoader/hiDream/HiDreamModelLoader.py | 16 +- .../hunyuanVideo/HunyuanVideoModelLoader.py | 16 +- .../modelLoader/mixin/HFModelLoaderMixin.py | 22 +- .../pixartAlpha/PixArtAlphaModelLoader.py | 11 +- modules/modelLoader/qwen/QwenModelLoader.py | 16 +- modules/modelLoader/sana/SanaModelLoader.py | 11 +- .../StableDiffusionModelLoader.py | 19 +- .../StableDiffusion3ModelLoader.py | 16 +- .../StableDiffusionXLModelLoader.py | 20 +- .../wuerstchen/WuerstchenModelLoader.py | 13 +- modules/trainer/GenericTrainer.py | 7 + modules/ui/ConvertModelUI.py | 2 + modules/ui/ModelTab.py | 277 ++++++++++++------ modules/ui/SampleWindow.py | 6 + modules/util/config/TrainConfig.py | 5 + modules/util/quantization_util.py | 26 +- 50 files changed, 460 insertions(+), 185 deletions(-) diff --git a/modules/modelLoader/ChromaEmbeddingModelLoader.py b/modules/modelLoader/ChromaEmbeddingModelLoader.py index ea03d5757..c113d0877 100644 --- a/modules/modelLoader/ChromaEmbeddingModelLoader.py +++ b/modules/modelLoader/ChromaEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class ChromaEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> ChromaModel | None: base_model_loader = ChromaModelLoader() embedding_loader = ChromaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, 
weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/ChromaFineTuneModelLoader.py b/modules/modelLoader/ChromaFineTuneModelLoader.py index c9bbb1891..0915d5509 100644 --- a/modules/modelLoader/ChromaFineTuneModelLoader.py +++ b/modules/modelLoader/ChromaFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class ChromaFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> ChromaModel | None: base_model_loader = ChromaModelLoader() embedding_loader = ChromaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/ChromaLoRAModelLoader.py b/modules/modelLoader/ChromaLoRAModelLoader.py index aaa38058c..37d07644a 100644 --- a/modules/modelLoader/ChromaLoRAModelLoader.py +++ b/modules/modelLoader/ChromaLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class ChromaLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> ChromaModel | None: base_model_loader = ChromaModelLoader() lora_model_loader = ChromaLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/FluxEmbeddingModelLoader.py b/modules/modelLoader/FluxEmbeddingModelLoader.py index 0960c051a..09baf3dbb 100644 --- a/modules/modelLoader/FluxEmbeddingModelLoader.py +++ b/modules/modelLoader/FluxEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class FluxEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> FluxModel | None: base_model_loader = FluxModelLoader() embedding_loader = FluxEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, 
model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/FluxFineTuneModelLoader.py b/modules/modelLoader/FluxFineTuneModelLoader.py index 4fac329a3..599cdfb10 100644 --- a/modules/modelLoader/FluxFineTuneModelLoader.py +++ b/modules/modelLoader/FluxFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class FluxFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> FluxModel | None: base_model_loader = FluxModelLoader() embedding_loader = FluxEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/FluxLoRAModelLoader.py b/modules/modelLoader/FluxLoRAModelLoader.py index 84d9133dc..a7a554370 100644 --- a/modules/modelLoader/FluxLoRAModelLoader.py +++ b/modules/modelLoader/FluxLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class FluxLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> FluxModel | None: base_model_loader = FluxModelLoader() lora_model_loader = FluxLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/HiDreamEmbeddingModelLoader.py b/modules/modelLoader/HiDreamEmbeddingModelLoader.py index 7539fdf70..4583db403 100644 --- a/modules/modelLoader/HiDreamEmbeddingModelLoader.py +++ b/modules/modelLoader/HiDreamEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HiDreamEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HiDreamModel | None: base_model_loader = HiDreamModelLoader() embedding_loader = HiDreamEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git 
a/modules/modelLoader/HiDreamFineTuneModelLoader.py b/modules/modelLoader/HiDreamFineTuneModelLoader.py index d9de7fc7a..3d4b27cb4 100644 --- a/modules/modelLoader/HiDreamFineTuneModelLoader.py +++ b/modules/modelLoader/HiDreamFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HiDreamFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HiDreamModel | None: base_model_loader = HiDreamModelLoader() embedding_loader = HiDreamEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/HiDreamLoRAModelLoader.py b/modules/modelLoader/HiDreamLoRAModelLoader.py index 524ab3e89..90898fc73 100644 --- a/modules/modelLoader/HiDreamLoRAModelLoader.py +++ b/modules/modelLoader/HiDreamLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HiDreamLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HiDreamModel | None: base_model_loader = HiDreamModelLoader() lora_model_loader = HiDreamLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py b/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py index d2e28c9b7..eeebf2f98 100644 --- a/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py +++ b/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HunyuanVideoEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HunyuanVideoModel | None: base_model_loader = HunyuanVideoModelLoader() embedding_loader = HunyuanVideoEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git 
a/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py b/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py index b2ae057ff..f81c8f94b 100644 --- a/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py +++ b/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HunyuanVideoFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HunyuanVideoModel | None: base_model_loader = HunyuanVideoModelLoader() embedding_loader = HunyuanVideoEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/HunyuanVideoLoRAModelLoader.py b/modules/modelLoader/HunyuanVideoLoRAModelLoader.py index 2aa4555c2..a88b8d836 100644 --- a/modules/modelLoader/HunyuanVideoLoRAModelLoader.py +++ b/modules/modelLoader/HunyuanVideoLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HunyuanVideoLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HunyuanVideoModel | None: base_model_loader = HunyuanVideoModelLoader() lora_model_loader = HunyuanVideoLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py b/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py index 52b00ca26..1126589b1 100644 --- a/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py +++ b/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class PixArtAlphaEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> PixArtAlphaModel | None: base_model_loader = PixArtAlphaModelLoader() embedding_loader = PixArtAlphaEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, 
model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py b/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py index 45c52c331..1f2ff7de8 100644 --- a/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py +++ b/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class PixArtAlphaFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> PixArtAlphaModel | None: base_model_loader = PixArtAlphaModelLoader() embedding_loader = PixArtAlphaEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/PixArtAlphaLoRAModelLoader.py b/modules/modelLoader/PixArtAlphaLoRAModelLoader.py index 028f7289f..27b42e96f 100644 --- a/modules/modelLoader/PixArtAlphaLoRAModelLoader.py +++ b/modules/modelLoader/PixArtAlphaLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class PixArtAlphaLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> PixArtAlphaModel | None: base_model_loader = PixArtAlphaModelLoader() lora_model_loader = PixArtAlphaLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/QwenFineTuneModelLoader.py b/modules/modelLoader/QwenFineTuneModelLoader.py index 0172d2b79..3e6d89709 100644 --- a/modules/modelLoader/QwenFineTuneModelLoader.py +++ b/modules/modelLoader/QwenFineTuneModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class QwenFineTuneModelLoader( @@ -31,6 +32,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> QwenModel | None: base_model_loader = QwenModelLoader() model = QwenModel(model_type=model_type) @@ -38,6 +40,6 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) return model diff 
--git a/modules/modelLoader/QwenLoRAModelLoader.py b/modules/modelLoader/QwenLoRAModelLoader.py index 41003cd07..3ae11114c 100644 --- a/modules/modelLoader/QwenLoRAModelLoader.py +++ b/modules/modelLoader/QwenLoRAModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class QwenLoRAModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> QwenModel | None: base_model_loader = QwenModelLoader() lora_model_loader = QwenLoRALoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) return model diff --git a/modules/modelLoader/SanaEmbeddingModelLoader.py b/modules/modelLoader/SanaEmbeddingModelLoader.py index 6781d2ba9..7215fe7c7 100644 --- a/modules/modelLoader/SanaEmbeddingModelLoader.py +++ b/modules/modelLoader/SanaEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class SanaEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: base_model_loader = SanaModelLoader() embedding_loader = SanaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/SanaFineTuneModelLoader.py b/modules/modelLoader/SanaFineTuneModelLoader.py index 82b14e49c..a713cb555 100644 --- a/modules/modelLoader/SanaFineTuneModelLoader.py +++ b/modules/modelLoader/SanaFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class SanaFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: base_model_loader = SanaModelLoader() embedding_loader = SanaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/SanaLoRAModelLoader.py b/modules/modelLoader/SanaLoRAModelLoader.py index bfaf8a1da..f326cb6fa 100644 --- 
a/modules/modelLoader/SanaLoRAModelLoader.py +++ b/modules/modelLoader/SanaLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class SanaLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: base_model_loader = SanaModelLoader() lora_model_loader = SanaLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py b/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py index 86d0ff771..4cdbddff2 100644 --- a/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py +++ b/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusion3EmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusion3Model | None: base_model_loader = StableDiffusion3ModelLoader() embedding_loader = StableDiffusion3EmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py b/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py index 8a9ee4477..a3fd19add 100644 --- a/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py +++ b/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusion3FineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusion3Model | None: base_model_loader = StableDiffusion3ModelLoader() embedding_loader = StableDiffusion3EmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git 
a/modules/modelLoader/StableDiffusion3LoRAModelLoader.py b/modules/modelLoader/StableDiffusion3LoRAModelLoader.py index 994d87518..34a53024d 100644 --- a/modules/modelLoader/StableDiffusion3LoRAModelLoader.py +++ b/modules/modelLoader/StableDiffusion3LoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusion3LoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusion3Model | None: base_model_loader = StableDiffusion3ModelLoader() lora_model_loader = StableDiffusion3LoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py b/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py index e698eb3d5..d81bb673d 100644 --- a/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py +++ b/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionEmbeddingModelLoader( @@ -46,6 +47,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionModel | None: base_model_loader = StableDiffusionModelLoader() embedding_loader = StableDiffusionEmbeddingLoader() @@ -55,7 +57,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/StableDiffusionFineTuneModelLoader.py b/modules/modelLoader/StableDiffusionFineTuneModelLoader.py index acfdefef0..adc5448fc 100644 --- a/modules/modelLoader/StableDiffusionFineTuneModelLoader.py +++ b/modules/modelLoader/StableDiffusionFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionFineTuneModelLoader( @@ -46,6 +47,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionModel | None: base_model_loader = StableDiffusionModelLoader() embedding_loader = StableDiffusionEmbeddingLoader() @@ -55,7 +57,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, 
model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/StableDiffusionLoRAModelLoader.py b/modules/modelLoader/StableDiffusionLoRAModelLoader.py index ff6503aa1..958a684b5 100644 --- a/modules/modelLoader/StableDiffusionLoRAModelLoader.py +++ b/modules/modelLoader/StableDiffusionLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionLoRAModelLoader( @@ -47,6 +48,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionModel | None: base_model_loader = StableDiffusionModelLoader() lora_model_loader = StableDiffusionLoRALoader() @@ -57,7 +59,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py b/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py index 201e993cb..bec654875 100644 --- a/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py +++ b/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionXLEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionXLModel | None: base_model_loader = StableDiffusionXLModelLoader() embedding_loader = StableDiffusionXLEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py b/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py index 1e1efdc30..2f7e09ea3 100644 --- a/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py +++ b/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionXLFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionXLModel | None: base_model_loader = StableDiffusionXLModelLoader() embedding_loader = StableDiffusionXLEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, 
model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py b/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py index 15192bd81..d84cb7a2d 100644 --- a/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py +++ b/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionXLLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionXLModel | None: base_model_loader = StableDiffusionXLModelLoader() lora_model_loader = StableDiffusionXLLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/WuerstchenEmbeddingModelLoader.py b/modules/modelLoader/WuerstchenEmbeddingModelLoader.py index e10e7b6de..3cfe1b624 100644 --- a/modules/modelLoader/WuerstchenEmbeddingModelLoader.py +++ b/modules/modelLoader/WuerstchenEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class WuerstchenEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> WuerstchenModel | None: base_model_loader = WuerstchenModelLoader() embedding_loader = WuerstchenEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/WuerstchenFineTuneModelLoader.py b/modules/modelLoader/WuerstchenFineTuneModelLoader.py index 0182b31a0..df8ac40b1 100644 --- a/modules/modelLoader/WuerstchenFineTuneModelLoader.py +++ b/modules/modelLoader/WuerstchenFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class WuerstchenFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> WuerstchenModel | None: base_model_loader = WuerstchenModelLoader() embedding_loader = 
WuerstchenEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/WuerstchenLoRAModelLoader.py b/modules/modelLoader/WuerstchenLoRAModelLoader.py index fcd9f91ab..51d46760f 100644 --- a/modules/modelLoader/WuerstchenLoRAModelLoader.py +++ b/modules/modelLoader/WuerstchenLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class WuerstchenLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> WuerstchenModel | None: base_model_loader = WuerstchenModelLoader() lora_model_loader = WuerstchenLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/chroma/ChromaModelLoader.py b/modules/modelLoader/chroma/ChromaModelLoader.py index 1424e566a..daeb9331f 100644 --- a/modules/modelLoader/chroma/ChromaModelLoader.py +++ b/modules/modelLoader/chroma/ChromaModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import torch @@ -33,10 +34,11 @@ def __load_internal( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( - model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, + model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quant_filters, ) else: raise Exception("not an internal model") @@ -49,6 +51,7 @@ def __load_diffusers( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] if not transformer_model_name: @@ -104,7 +107,7 @@ def __load_diffusers( quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) else: transformer = self._load_diffusers_sub_module( @@ -113,6 +116,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -130,6 +134,7 @@ def __load_safetensors( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): #no single file .safetensors 
for Chroma available at the time of writing this code raise NotImplementedError("Loading of single file Chroma models not supported. Use the diffusers model instead. Optionally, transformer-only safetensor files can be loaded by overriding the transformer.") @@ -140,12 +145,13 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -153,7 +159,7 @@ def load( try: self.__load_diffusers( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -161,7 +167,7 @@ def load( try: self.__load_safetensors( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: diff --git a/modules/modelLoader/flux/FluxModelLoader.py b/modules/modelLoader/flux/FluxModelLoader.py index 254f9a3ee..cda83c7e1 100644 --- a/modules/modelLoader/flux/FluxModelLoader.py +++ b/modules/modelLoader/flux/FluxModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import torch @@ -36,11 +37,12 @@ def __load_internal( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, + include_text_encoder_1, include_text_encoder_2, quant_filters, ) else: raise Exception("not an internal model") @@ -55,6 +57,7 @@ def __load_diffusers( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] transformers_sub = [] @@ -140,7 +143,7 @@ def __load_diffusers( quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) else: transformer = self._load_diffusers_sub_module( @@ -149,6 +152,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -170,6 +174,7 @@ def __load_safetensors( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): transformer = FluxTransformer2DModel.from_single_file( #always load transformer separately even though FluxPipeLine.from_single_file() could load it, to avoid loading in float32: @@ -222,7 +227,7 @@ def 
__load_safetensors( print("text encoder 2 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -240,13 +245,14 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) return except Exception: @@ -255,7 +261,7 @@ def load( try: self.__load_diffusers( model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) return except Exception: @@ -264,7 +270,7 @@ def load( try: self.__load_safetensors( model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) return except Exception: diff --git a/modules/modelLoader/hiDream/HiDreamModelLoader.py b/modules/modelLoader/hiDream/HiDreamModelLoader.py index 92bce0a9a..e6e5ca462 100644 --- a/modules/modelLoader/hiDream/HiDreamModelLoader.py +++ b/modules/modelLoader/hiDream/HiDreamModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import ( AutoencoderKL, @@ -42,11 +43,12 @@ def __load_internal( include_text_encoder_2: bool, include_text_encoder_3: bool, include_text_encoder_4: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, text_encoder_4_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, include_text_encoder_4, + include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, include_text_encoder_4, quant_filters, ) else: raise Exception("not an internal model") @@ -63,6 +65,7 @@ def __load_diffusers( include_text_encoder_2: bool, include_text_encoder_3: bool, include_text_encoder_4: bool, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] transformers_sub = [] @@ -191,6 +194,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -218,6 +222,7 @@ def __load_safetensors( include_text_encoder_2: bool, include_text_encoder_3: bool, include_text_encoder_4: bool, + quant_filters: list[ModuleFilter], ): pipeline = HiDreamImagePipeline.from_single_file( pretrained_model_link_or_path=base_model_name, @@ -264,7 +269,7 @@ def __load_safetensors( print("text encoder 2 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, 
weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -290,6 +295,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -298,7 +304,7 @@ def load( model, model_type, weight_dtypes, model_names.base_model, model_names.text_encoder_4, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, model_names.include_text_encoder_4, + model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters, ) self.__after_load(model) return @@ -310,7 +316,7 @@ def load( model, model_type, weight_dtypes, model_names.base_model, model_names.text_encoder_4, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, model_names.include_text_encoder_4, + model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters, ) self.__after_load(model) return @@ -322,7 +328,7 @@ def load( model, model_type, weight_dtypes, model_names.base_model, model_names.text_encoder_4, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, model_names.include_text_encoder_4, + model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters, ) self.__after_load(model) return diff --git a/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py b/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py index fe044ca03..e7e8996fb 100644 --- a/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py +++ b/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import ( AutoencoderKLHunyuanVideo, @@ -32,11 +33,12 @@ def __load_internal( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, + include_text_encoder_1, include_text_encoder_2, quant_filters, ) else: raise Exception("not an internal model") @@ -50,6 +52,7 @@ def __load_diffusers( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] transformers_sub = [] @@ -133,6 +136,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -153,6 +157,7 @@ def __load_safetensors( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): pipeline = HunyuanVideoPipeline.from_single_file( pretrained_model_link_or_path=base_model_name, @@ -192,7 +197,7 @@ def __load_safetensors( print("text encoder 2 (clip l) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, 
quant_filters, ) model.model_type = model_type @@ -214,13 +219,14 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) self.__after_load(model) return @@ -230,7 +236,7 @@ def load( try: self.__load_diffusers( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) self.__after_load(model) return @@ -240,7 +246,7 @@ def load( try: self.__load_safetensors( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) self.__after_load(model) return diff --git a/modules/modelLoader/mixin/HFModelLoaderMixin.py b/modules/modelLoader/mixin/HFModelLoaderMixin.py index 163702417..0cb3d3c92 100644 --- a/modules/modelLoader/mixin/HFModelLoaderMixin.py +++ b/modules/modelLoader/mixin/HFModelLoaderMixin.py @@ -6,6 +6,7 @@ from itertools import repeat from modules.util.enum.DataType import DataType +from modules.util.ModuleFilter import ModuleFilter from modules.util.quantization_util import ( is_quantized_parameter, replace_linear_with_fp8_layers, @@ -32,6 +33,7 @@ def __load_sub_module( dtype: DataType, train_dtype: DataType, keep_in_fp32_modules: list[str] | None, + quant_filters: list[ModuleFilter] | None, pretrained_model_name_or_path: str, subfolder: str | None, model_filename: str, @@ -43,11 +45,11 @@ def __load_sub_module( with accelerate.init_empty_weights(): if dtype.quantize_nf4(): - replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, copy_parameters=False) + replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, quant_filters, copy_parameters=False) elif dtype.quantize_int8(): - replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, copy_parameters=False) + replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, quant_filters, copy_parameters=False) elif dtype.quantize_fp8(): - replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, copy_parameters=False) + replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, quant_filters, copy_parameters=False) is_local = os.path.isdir(pretrained_model_name_or_path) @@ -195,6 +197,7 @@ def _load_transformers_sub_module( dtype=dtype, train_dtype=train_dtype, keep_in_fp32_modules=module_type._keep_in_fp32_modules, + quant_filters=None, pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, model_filename="model.safetensors", @@ -209,6 +212,7 @@ def _load_diffusers_sub_module( train_dtype: DataType, pretrained_model_name_or_path: str, subfolder: str | None = None, + quant_filters: list[ModuleFilter] | None = None, ): user_agent = { "file_type": "model", @@ -230,6 +234,7 @@ def _load_diffusers_sub_module( dtype=dtype, train_dtype=train_dtype, keep_in_fp32_modules=module_type._keep_in_fp32_modules, + quant_filters=quant_filters, pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, 
model_filename="diffusion_pytorch_model.safetensors", @@ -243,16 +248,17 @@ def __convert_sub_module_to_dtype( dtype: DataType, train_dtype: DataType, keep_in_fp32_modules: list[str] | None, + quant_filters: list[ModuleFilter] | None, ): if keep_in_fp32_modules is None: keep_in_fp32_modules = [] if dtype.quantize_nf4(): - replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, copy_parameters=True) + replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, quant_filters, copy_parameters=True) elif dtype.quantize_int8(): - replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, copy_parameters=True) + replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, quant_filters, copy_parameters=True) elif dtype.quantize_fp8(): - replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, copy_parameters=True) + replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, quant_filters, copy_parameters=True) for module_name, module in sub_module.named_modules(): param_iter = [(x, y[0], y[1]) for x, y in zip(repeat(False), module._parameters.items(), strict=False)] @@ -281,6 +287,7 @@ def _convert_transformers_sub_module_to_dtype( sub_module: nn.Module, dtype: DataType, train_dtype: DataType, + quant_filters: list[ModuleFilter] | None = None, ): module_type = type(sub_module) @@ -289,6 +296,7 @@ def _convert_transformers_sub_module_to_dtype( dtype, train_dtype, module_type._keep_in_fp32_modules, + quant_filters, ) def _convert_diffusers_sub_module_to_dtype( @@ -296,12 +304,14 @@ def _convert_diffusers_sub_module_to_dtype( sub_module: nn.Module, dtype: DataType, train_dtype: DataType, + quant_filters: list[ModuleFilter] | None = None, ): return self.__convert_sub_module_to_dtype( sub_module, dtype, train_dtype, None, + quant_filters, ) def _prepare_sub_modules(self, pretrained_model_name_or_path: str, diffusers_modules: list[str], transformers_modules: list[str]): diff --git a/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py b/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py index 1428a4079..aef761608 100644 --- a/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py +++ b/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import AutoencoderKL, DDIMScheduler, Transformer2DModel from transformers import T5EncoderModel, T5Tokenizer @@ -24,9 +25,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -37,6 +39,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer = T5Tokenizer.from_pretrained( base_model_name, @@ -78,6 +81,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -93,19 +97,20 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | 
None = None, ) -> PixArtAlphaModel | None: stacktraces = [] base_model_name = model_names.base_model try: - self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/qwen/QwenModelLoader.py b/modules/modelLoader/qwen/QwenModelLoader.py index 85b5817f1..377710258 100644 --- a/modules/modelLoader/qwen/QwenModelLoader.py +++ b/modules/modelLoader/qwen/QwenModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import torch @@ -33,10 +34,11 @@ def __load_internal( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( - model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, + model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quant_filters, ) else: raise Exception("not an internal model") @@ -49,6 +51,7 @@ def __load_diffusers( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] if not transformer_model_name: @@ -106,7 +109,7 @@ def __load_diffusers( quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) else: transformer = self._load_diffusers_sub_module( @@ -115,6 +118,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -132,6 +136,7 @@ def __load_safetensors( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): #no single file .safetensors for Qwen available at the time of writing this code raise NotImplementedError("Loading of single file Qwen models not supported. Use the diffusers model instead. 
Optionally, transformer-only safetensor files can be loaded by overriding the transformer.") @@ -142,12 +147,13 @@ def load( #TODO share code between models model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -155,7 +161,7 @@ def load( #TODO share code between models try: self.__load_diffusers( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -163,7 +169,7 @@ def load( #TODO share code between models try: self.__load_safetensors( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: diff --git a/modules/modelLoader/sana/SanaModelLoader.py b/modules/modelLoader/sana/SanaModelLoader.py index e4c2d638f..0c2b5fc4e 100644 --- a/modules/modelLoader/sana/SanaModelLoader.py +++ b/modules/modelLoader/sana/SanaModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import AutoencoderDC, DPMSolverMultistepScheduler, SanaTransformer2DModel from transformers import Gemma2Model, GemmaTokenizer @@ -24,9 +25,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -37,6 +39,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer = GemmaTokenizer.from_pretrained( base_model_name, @@ -78,6 +81,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -93,19 +97,20 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: stacktraces = [] base_model_name = model_names.base_model try: - self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: 
stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py b/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py index 2e55c8b5a..6aee93334 100644 --- a/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py +++ b/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py @@ -9,6 +9,7 @@ from modules.util.enum.NoiseScheduler import NoiseScheduler from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import torch @@ -55,9 +56,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -68,6 +70,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer = CLIPTokenizer.from_pretrained( base_model_name, @@ -113,6 +116,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "unet", + quant_filters, ) image_depth_processor = DPTImageProcessor.from_pretrained( @@ -158,6 +162,7 @@ def __load_ckpt( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): state_dict = torch.load(base_model_name, weights_only=True) state_dict = self.__fix_nai_model(state_dict) @@ -196,7 +201,7 @@ def __load_ckpt( pipeline.text_encoder, weight_dtypes.text_encoder, weight_dtypes.train_dtype ) unet = self._convert_diffusers_sub_module_to_dtype( - pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype + pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -215,6 +220,7 @@ def __load_safetensors( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): num_in_channels = 4 if model_type.has_mask_input(): @@ -251,7 +257,7 @@ def __load_safetensors( pipeline.text_encoder, weight_dtypes.text_encoder, weight_dtypes.train_dtype ) unet = self._convert_diffusers_sub_module_to_dtype( - pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype + pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -269,6 +275,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -276,19 +283,19 @@ def load( model.sd_config_filename = self._get_sd_config_name(model_type, model_names.base_model) try: - self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_diffusers(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - 
self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py b/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py index f5fa6871f..0852b20a9 100644 --- a/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py +++ b/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer @@ -27,11 +28,12 @@ def __load_internal( include_text_encoder_1: bool, include_text_encoder_2: bool, include_text_encoder_3: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, + include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, quant_filters, ) else: raise Exception("not an internal model") @@ -46,6 +48,7 @@ def __load_diffusers( include_text_encoder_1: bool, include_text_encoder_2: bool, include_text_encoder_3: bool, + quant_filters: list[ModuleFilter], ): #no call to self._prepare_sub_modules, because SAI polluted their sd3 / sd3.5 medium repo text encoders with fp16 files @@ -133,6 +136,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -156,6 +160,7 @@ def __load_safetensors( include_text_encoder_1: bool, include_text_encoder_2: bool, include_text_encoder_3: bool, + quant_filters: list[ModuleFilter], ): pipeline = StableDiffusion3Pipeline.from_single_file( pretrained_model_link_or_path=base_model_name, @@ -220,7 +225,7 @@ def __load_safetensors( print("text encoder 3 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -240,6 +245,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -247,7 +253,7 @@ def load( self.__load_internal( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, + model_names.include_text_encoder_3, quant_filters, ) return except Exception: @@ -257,7 +263,7 @@ def load( self.__load_diffusers( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, + model_names.include_text_encoder_3, quant_filters, ) return except Exception: @@ -267,7 +273,7 @@ def load( 
self.__load_safetensors( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, + model_names.include_text_encoder_3, quant_filters, ) return except Exception: diff --git a/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py b/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py index 75bf33d98..4ffb61777 100644 --- a/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py +++ b/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py @@ -9,6 +9,7 @@ from modules.util.enum.NoiseScheduler import NoiseScheduler from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import ( AutoencoderKL, @@ -46,9 +47,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -59,6 +61,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer_1 = CLIPTokenizer.from_pretrained( base_model_name, @@ -117,6 +120,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "unet", + quant_filters, ) model.model_type = model_type @@ -135,7 +139,11 @@ def __load_ckpt( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): + if quant_filters is not None and len(quant_filters) > 0: + raise NotImplementedError("Quantization not implemented for loading ckpt files") + pipeline = StableDiffusionXLPipeline.from_single_file( pretrained_model_link_or_path=base_model_name, original_config=model.sd_config_filename, @@ -176,6 +184,7 @@ def __load_safetensors( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if model_type.has_conditioning_image_input(): pipeline = StableDiffusionXLInpaintPipeline.from_single_file( @@ -216,7 +225,7 @@ def __load_safetensors( pipeline.text_encoder_2, weight_dtypes.text_encoder_2, weight_dtypes.train_dtype ) unet = self._convert_diffusers_sub_module_to_dtype( - pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype + pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -234,6 +243,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -241,19 +251,19 @@ def load( model.sd_config_filename = self._get_sd_config_name(model_type, model_names.base_model) try: - self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_diffusers(model, model_type, 
weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py b/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py index c7c237cc7..d3419706f 100644 --- a/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py +++ b/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import DDPMWuerstchenScheduler from diffusers.models import StableCascadeUNet @@ -32,6 +33,7 @@ def __load_internal( prior_model_name: str, effnet_encoder_model_name: str, decoder_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(prior_model_name, "meta.json")): self.__load_diffusers( @@ -42,6 +44,7 @@ def __load_internal( "", # pass an empty prior name, so it's always loaded from the backup effnet_encoder_model_name, decoder_model_name, + quant_filters, ) else: raise Exception("not an internal model") @@ -55,6 +58,7 @@ def __load_diffusers( prior_prior_model_name: str, effnet_encoder_model_name: str, decoder_model_name: str, + quant_filters: list[ModuleFilter], ): if model_type.is_wuerstchen_v2(): decoder_tokenizer = CLIPTokenizer.from_pretrained( @@ -127,6 +131,7 @@ def __load_diffusers( weight_dtypes.train_dtype, prior_model_name, "prior", + quant_filters, ) elif model_type.is_stable_cascade(): if prior_prior_model_name: @@ -140,7 +145,7 @@ def __load_diffusers( prior_prior = StableCascadeUNet(**prior_config) prior_prior.load_state_dict(convert_stable_cascade_ckpt_to_diffusers(load_file(prior_prior_model_name))) prior_prior = self._convert_diffusers_sub_module_to_dtype( - prior_prior, weight_dtypes.prior, weight_dtypes.fallback_train_dtype + prior_prior, weight_dtypes.prior, weight_dtypes.fallback_train_dtype, quant_filters, ) else: prior_prior = self._load_diffusers_sub_module( @@ -149,6 +154,7 @@ def __load_diffusers( weight_dtypes.fallback_train_dtype, prior_model_name, "prior", + quant_filters, ) prior_tokenizer = CLIPTokenizer.from_pretrained( @@ -196,6 +202,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -211,7 +218,7 @@ def load( weight_dtypes, prior_model_name, effnet_encoder_model_name, - decoder_model_name, + decoder_model_name, quant_filters, ) return except Exception: @@ -225,7 +232,7 @@ def load( prior_model_name, prior_prior_model_name, effnet_encoder_model_name, - decoder_model_name, + decoder_model_name, quant_filters, ) return except Exception: diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 7dfa76398..13b45884c 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -29,6 +29,7 @@ from modules.util.enum.TimeUnit import TimeUnit from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.memory_util import TorchMemoryRecorder +from modules.util.ModuleFilter import 
ModuleFilter from modules.util.time_util import get_string_timestamp from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress @@ -120,10 +121,16 @@ def start(self): ) self.callbacks.on_update_status("loading the model") + + quant_filters = [ + ModuleFilter(pattern, use_regex=self.config.quantization_layer_filter_regex) + for pattern in self.config.quantization_layer_filter.split(",") + ] self.model = self.model_loader.load( model_type=self.config.model_type, model_names=model_names, weight_dtypes=self.config.weight_dtypes(), + quant_filters=quant_filters, ) self.model.train_config = self.config diff --git a/modules/ui/ConvertModelUI.py b/modules/ui/ConvertModelUI.py index 685aeaa85..837b27614 100644 --- a/modules/ui/ConvertModelUI.py +++ b/modules/ui/ConvertModelUI.py @@ -131,6 +131,7 @@ def convert_model(self): base_model=self.convert_model_args.input_name, ), weight_dtypes=self.convert_model_args.weight_dtypes(), + #TODO quantization layer filter ) elif self.convert_model_args.training_method in [TrainingMethod.LORA, TrainingMethod.EMBEDDING]: model = model_loader.load( @@ -140,6 +141,7 @@ def convert_model(self): embedding=EmbeddingName(str(uuid4()), self.convert_model_args.input_name), ), weight_dtypes=self.convert_model_args.weight_dtypes(), + #TODO quantization layer filter ) else: raise Exception("could not load model: " + self.convert_model_args.input_name) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 27b7fd34f..8c43a350f 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -1,5 +1,16 @@ from pathlib import Path +from modules.modelSetup.BaseChromaSetup import PRESETS as chroma_presets +from modules.modelSetup.BaseFluxSetup import PRESETS as flux_presets +from modules.modelSetup.BaseHiDreamSetup import PRESETS as hidream_presets +from modules.modelSetup.BaseHunyuanVideoSetup import PRESETS as hunyuan_video_presets +from modules.modelSetup.BasePixArtAlphaSetup import PRESETS as pixart_presets +from modules.modelSetup.BaseQwenSetup import PRESETS as qwen_presets +from modules.modelSetup.BaseSanaSetup import PRESETS as sana_presets +from modules.modelSetup.BaseStableDiffusion3Setup import PRESETS as sd3_presets +from modules.modelSetup.BaseStableDiffusionSetup import PRESETS as sd_presets +from modules.modelSetup.BaseStableDiffusionXLSetup import PRESETS as sdxl_presets +from modules.modelSetup.BaseWuerstchenSetup import PRESETS as sc_presets from modules.util.config.TrainConfig import TrainConfig from modules.util.enum.ConfigPart import ConfigPart from modules.util.enum.DataType import DataType @@ -33,46 +44,101 @@ def refresh_ui(self): self.scroll_frame = ctk.CTkScrollableFrame(self.master, fg_color="transparent") self.scroll_frame.grid(row=0, column=0, sticky="nsew") + self.scroll_frame.grid_columnconfigure(0, weight=1) - self.scroll_frame.grid_columnconfigure(0, weight=0) - self.scroll_frame.grid_columnconfigure(1, weight=10) - self.scroll_frame.grid_columnconfigure(2, minsize=50) - self.scroll_frame.grid_columnconfigure(3, weight=0) - self.scroll_frame.grid_columnconfigure(4, weight=1) + base_frame = ctk.CTkFrame(master=self.scroll_frame, corner_radius=5) + base_frame.grid(row=0, column=0, padx=5, pady=5, sticky="nsew") + + base_frame.grid_columnconfigure(0, weight=0) + base_frame.grid_columnconfigure(1, weight=10)#, minsize=500) + base_frame.grid_columnconfigure(2, minsize=50) + base_frame.grid_columnconfigure(3, weight=0) + base_frame.grid_columnconfigure(4, weight=1) if 
self.train_config.model_type.is_stable_diffusion(): #TODO simplify - self.__setup_stable_diffusion_ui() + self.__setup_stable_diffusion_ui(base_frame) if self.train_config.model_type.is_stable_diffusion_3(): - self.__setup_stable_diffusion_3_ui() + self.__setup_stable_diffusion_3_ui(base_frame) elif self.train_config.model_type.is_stable_diffusion_xl(): - self.__setup_stable_diffusion_xl_ui() + self.__setup_stable_diffusion_xl_ui(base_frame) elif self.train_config.model_type.is_wuerstchen(): - self.__setup_wuerstchen_ui() + self.__setup_wuerstchen_ui(base_frame) elif self.train_config.model_type.is_pixart(): - self.__setup_pixart_alpha_ui() + self.__setup_pixart_alpha_ui(base_frame) elif self.train_config.model_type.is_flux(): - self.__setup_flux_ui() + self.__setup_flux_ui(base_frame) elif self.train_config.model_type.is_chroma(): - self.__setup_chroma_ui() + self.__setup_chroma_ui(base_frame) elif self.train_config.model_type.is_qwen(): - self.__setup_qwen_ui() + self.__setup_qwen_ui(base_frame) elif self.train_config.model_type.is_sana(): - self.__setup_sana_ui() + self.__setup_sana_ui(base_frame) elif self.train_config.model_type.is_hunyuan_video(): - self.__setup_hunyuan_video_ui() + self.__setup_hunyuan_video_ui(base_frame) elif self.train_config.model_type.is_hi_dream(): - self.__setup_hi_dream_ui() + self.__setup_hi_dream_ui(base_frame) + + self.__create_quantization_frame(self.scroll_frame, row=1, column=0) + + def __create_quantization_frame( + self, + master, + row: int, + column: int, + ): + frame = ctk.CTkFrame(master=master, corner_radius=5, width=300) + frame.grid(row=row, column=column, padx=5, pady=5, sticky="nsew") + frame.grid_columnconfigure(0, weight=1) + frame.grid_columnconfigure(1, weight=10) + + presets = [] + if self.train_config.model_type.is_stable_diffusion(): #TODO simplify and de-duplicate with layer filter on training tab + presets = sd_presets + elif self.train_config.model_type.is_stable_diffusion_xl(): + presets = sdxl_presets + elif self.train_config.model_type.is_stable_diffusion_3(): + presets = sd3_presets + elif self.train_config.model_type.is_wuerstchen(): + presets = sc_presets + elif self.train_config.model_type.is_pixart(): + presets = pixart_presets + elif self.train_config.model_type.is_flux(): + presets = flux_presets + elif self.train_config.model_type.is_qwen(): + presets = qwen_presets + elif self.train_config.model_type.is_chroma(): + presets = chroma_presets + elif self.train_config.model_type.is_sana(): + presets = sana_presets + elif self.train_config.model_type.is_hunyuan_video(): + presets = hunyuan_video_presets + elif self.train_config.model_type.is_hi_dream(): + presets = hidream_presets + else: + presets = {"full": []} + + components.layer_filter_entry(frame, 0, 0, self.ui_state, + preset_var_name="quantization_layer_filter_preset", presets=presets, + preset_label="Quantization Layer Filter", + preset_tooltip="Select a preset defining which layers to quantize. Quantization of certain layers can decrease model quality. Only applies to the transformer/unet", + entry_var_name="quantization_layer_filter", + entry_tooltip="Comma-separated list of layers to quantize. Regular expressions (if toggled) are supported. Any model layer with a matching name will be quantized", + regex_var_name="quantization_layer_filter_regex", + regex_tooltip="If enabled, layer filter patterns are interpreted as regular expressions. 
Otherwise, simple substring matching is used.", + ) - def __setup_stable_diffusion_ui(self): + def __setup_stable_diffusion_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_unet=True, has_text_encoder=True, has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method in [ @@ -82,10 +148,11 @@ def __setup_stable_diffusion_ui(self): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_stable_diffusion_3_ui(self): + def __setup_stable_diffusion_3_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder_1=True, @@ -94,16 +161,18 @@ def __setup_stable_diffusion_3_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_flux_ui(self): + def __setup_flux_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, allow_override_transformer=True, @@ -112,16 +181,18 @@ def __setup_flux_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_chroma_ui(self): + def __setup_chroma_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, allow_override_transformer=True, @@ -129,16 +200,18 @@ def __setup_chroma_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_qwen_ui(self): + def __setup_qwen_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, allow_override_transformer=True, @@ -146,16 +219,18 @@ def __setup_qwen_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_stable_diffusion_xl_ui(self): + def __setup_stable_diffusion_xl_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_unet=True, has_text_encoder_1=True, @@ -163,24 +238,27 @@ def __setup_stable_diffusion_xl_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == 
TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_wuerstchen_ui(self): + def __setup_wuerstchen_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_prior=True, allow_override_prior=self.train_config.model_type.is_stable_cascade(), has_text_encoder=True, ) - row = self.__create_effnet_encoder_components(row) - row = self.__create_decoder_components(row, self.train_config.model_type.is_wuerstchen_v2()) + row = self.__create_effnet_encoder_components(frame, row) + row = self.__create_decoder_components(frame, row, self.train_config.model_type.is_wuerstchen_v2()) row = self.__create_output_components( + frame, row, allow_safetensors=self.train_config.training_method != TrainingMethod.FINE_TUNE or self.train_config.model_type.is_stable_cascade(), @@ -188,42 +266,47 @@ def __setup_wuerstchen_ui(self): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_pixart_alpha_ui(self): + def __setup_pixart_alpha_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder=True, has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_sana_ui(self): + def __setup_sana_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder=True, has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=self.train_config.training_method != TrainingMethod.FINE_TUNE, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_hunyuan_video_ui(self): + def __setup_hunyuan_video_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder_1=True, @@ -231,16 +314,18 @@ def __setup_hunyuan_video_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_hi_dream_ui(self): + def __setup_hi_dream_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder_1=True, @@ -251,6 +336,7 @@ def __setup_hi_dream_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, @@ -275,28 +361,28 @@ def __create_dtype_options(self, include_none:bool=True, include_gguf=False) -> return options - def __create_base_dtype_components(self, row: int) -> int: + def 
__create_base_dtype_components(self, frame, row: int) -> int: # huggingface token - components.label(self.scroll_frame, row, 0, "Hugging Face Token", + components.label(frame, row, 0, "Hugging Face Token", tooltip="Enter your Hugging Face access token if you have used a protected Hugging Face repository below.\nThis value is stored separately, not saved to your configuration file. " "Go to https://huggingface.co/settings/tokens to create an access token.", wide_tooltip=True) - components.entry(self.scroll_frame, row, 1, self.ui_state, "secrets.huggingface_token") + components.entry(frame, row, 1, self.ui_state, "secrets.huggingface_token") row += 1 # base model - components.label(self.scroll_frame, row, 0, "Base Model", + components.label(frame, row, 0, "Base Model", tooltip="Filename, directory or Hugging Face repository of the base model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "base_model_name", + frame, row, 1, self.ui_state, "base_model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # weight dtype - components.label(self.scroll_frame, row, 3, "Weight Data Type", + components.label(frame, row, 3, "Weight Data Type", tooltip="The base model weight data type used for training. This can reduce memory consumption, but reduces precision") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(False), + components.options_kv(frame, row, 4, self.__create_dtype_options(False), self.ui_state, "weight_dtype") row += 1 @@ -305,6 +391,7 @@ def __create_base_dtype_components(self, row: int) -> int: def __create_base_components( self, + frame, row: int, has_unet: bool = False, has_prior: bool = False, @@ -321,9 +408,9 @@ def __create_base_components( ) -> int: if has_unet: # unet weight dtype - components.label(self.scroll_frame, row, 3, "Override UNet Data Type", + components.label(frame, row, 3, "Override UNet Data Type", tooltip="Overrides the unet weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "unet.weight_dtype") row += 1 @@ -331,17 +418,17 @@ def __create_base_components( if has_prior: if allow_override_prior: # prior model - components.label(self.scroll_frame, row, 0, "Prior Model", + components.label(frame, row, 0, "Prior Model", tooltip="Filename, directory or Hugging Face repository of the prior model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "prior.model_name", + frame, row, 1, self.ui_state, "prior.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # prior weight dtype - components.label(self.scroll_frame, row, 3, "Override Prior Data Type", + components.label(frame, row, 3, "Override Prior Data Type", tooltip="Overrides the prior weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "prior.weight_dtype") row += 1 @@ -349,53 +436,53 @@ def __create_base_components( if has_transformer: if allow_override_transformer: # transformer model - components.label(self.scroll_frame, row, 0, "Override Transformer / GGUF", + components.label(frame, row, 0, "Override Transformer / GGUF", tooltip="Can be used to override the transformer in the base model. Safetensors and GGUF files are supported, local and on Huggingface. 
If a GGUF file is used, the DataType must also be set to GGUF") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "transformer.model_name", + frame, row, 1, self.ui_state, "transformer.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # transformer weight dtype - components.label(self.scroll_frame, row, 3, "Override Transformer Data Type", + components.label(frame, row, 3, "Override Transformer Data Type", tooltip="Overrides the transformer weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(include_gguf=True), + components.options_kv(frame, row, 4, self.__create_dtype_options(include_gguf=True), self.ui_state, "transformer.weight_dtype") row += 1 if has_text_encoder: # text encoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder Data Type", + components.label(frame, row, 3, "Override Text Encoder Data Type", tooltip="Overrides the text encoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder.weight_dtype") row += 1 if has_text_encoder_1: # text encoder 1 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 1 Data Type", + components.label(frame, row, 3, "Override Text Encoder 1 Data Type", tooltip="Overrides the text encoder 1 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder.weight_dtype") row += 1 if has_text_encoder_2: # text encoder 2 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 2 Data Type", + components.label(frame, row, 3, "Override Text Encoder 2 Data Type", tooltip="Overrides the text encoder 2 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder_2.weight_dtype") row += 1 if has_text_encoder_3: # text encoder 3 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 3 Data Type", + components.label(frame, row, 3, "Override Text Encoder 3 Data Type", tooltip="Overrides the text encoder 3 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder_3.weight_dtype") row += 1 @@ -403,53 +490,53 @@ def __create_base_components( if has_text_encoder_4: if allow_override_text_encoder_4: # text encoder 4 weight dtype - components.label(self.scroll_frame, row, 0, "Text Encoder 4 Override", + components.label(frame, row, 0, "Text Encoder 4 Override", tooltip="Filename, directory or Hugging Face repository of the text encoder 4 model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "text_encoder_4.model_name", + frame, row, 1, self.ui_state, "text_encoder_4.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # text encoder 4 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 4 Data Type", + components.label(frame, row, 3, "Override Text Encoder 4 Data Type", tooltip="Overrides the text encoder 4 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), 
+ components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder_4.weight_dtype") row += 1 if has_vae: # base model - components.label(self.scroll_frame, row, 0, "VAE Override", + components.label(frame, row, 0, "VAE Override", tooltip="Directory or Hugging Face repository of a VAE model in diffusers format. Can be used to override the VAE included in the base model. Using a safetensor VAE file will cause an error that the model cannot be loaded.") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "vae.model_name", + frame, row, 1, self.ui_state, "vae.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # vae weight dtype - components.label(self.scroll_frame, row, 3, "Override VAE Data Type", + components.label(frame, row, 3, "Override VAE Data Type", tooltip="Overrides the vae weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "vae.weight_dtype") row += 1 return row - def __create_effnet_encoder_components(self, row: int): + def __create_effnet_encoder_components(self, frame, row: int): # effnet encoder model - components.label(self.scroll_frame, row, 0, "Effnet Encoder Model", + components.label(frame, row, 0, "Effnet Encoder Model", tooltip="Filename, directory or Hugging Face repository of the effnet encoder model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "effnet_encoder.model_name", + frame, row, 1, self.ui_state, "effnet_encoder.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # effnet encoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Effnet Encoder Data Type", + components.label(frame, row, 3, "Override Effnet Encoder Data Type", tooltip="Overrides the effnet encoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "effnet_encoder.weight_dtype") row += 1 @@ -458,38 +545,39 @@ def __create_effnet_encoder_components(self, row: int): def __create_decoder_components( self, + frame, row: int, has_text_encoder: bool, ) -> int: # decoder model - components.label(self.scroll_frame, row, 0, "Decoder Model", + components.label(frame, row, 0, "Decoder Model", tooltip="Filename, directory or Hugging Face repository of the decoder model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "decoder.model_name", + frame, row, 1, self.ui_state, "decoder.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # decoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Decoder Data Type", + components.label(frame, row, 3, "Override Decoder Data Type", tooltip="Overrides the decoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "decoder.weight_dtype") row += 1 if has_text_encoder: # decoder text encoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Decoder Text Encoder Data Type", + components.label(frame, row, 3, "Override Decoder Text Encoder Data Type", tooltip="Overrides the decoder text encoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + 
components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "decoder_text_encoder.weight_dtype") row += 1 # decoder vqgan weight dtype - components.label(self.scroll_frame, row, 3, "Override Decoder VQGAN Data Type", + components.label(frame, row, 3, "Override Decoder VQGAN Data Type", tooltip="Overrides the decoder vqgan weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "decoder_vqgan.weight_dtype") row += 1 @@ -498,20 +586,21 @@ def __create_decoder_components( def __create_output_components( self, + frame, row: int, allow_safetensors: bool = False, allow_diffusers: bool = False, allow_legacy_safetensors: bool = False, ) -> int: # output model destination - components.label(self.scroll_frame, row, 0, "Model Output Destination", + components.label(frame, row, 0, "Model Output Destination", tooltip="Filename or directory where the output model is saved") - components.file_entry(self.scroll_frame, row, 1, self.ui_state, "output_model_destination", is_output=True) + components.file_entry(frame, row, 1, self.ui_state, "output_model_destination", is_output=True) # output data type - components.label(self.scroll_frame, row, 3, "Output Data Type", + components.label(frame, row, 3, "Output Data Type", tooltip="Precision to use when saving the output model") - components.options_kv(self.scroll_frame, row, 4, [ + components.options_kv(frame, row, 4, [ ("float16", DataType.FLOAT_16), ("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16), @@ -530,17 +619,17 @@ def __create_output_components( # if allow_legacy_safetensors: # formats.append(("Legacy Safetensors", ModelFormat.LEGACY_SAFETENSORS)) - components.label(self.scroll_frame, row, 0, "Output Format", + components.label(frame, row, 0, "Output Format", tooltip="Format to use when saving the output model") - components.options_kv(self.scroll_frame, row, 1, formats, self.ui_state, "output_model_format") + components.options_kv(frame, row, 1, formats, self.ui_state, "output_model_format") # include config - components.label(self.scroll_frame, row, 3, "Include Config", + components.label(frame, row, 3, "Include Config", tooltip="Include the training configuration in the final model. Only supported for safetensors files. " "None: No config is included. " "Settings: All training settings are included. 
" "All: All settings, including the samples and concepts are included.") - components.options_kv(self.scroll_frame, row, 4, [ + components.options_kv(frame, row, 4, [ ("None", ConfigPart.NONE), ("Settings", ConfigPart.SETTINGS), ("All", ConfigPart.ALL), diff --git a/modules/ui/SampleWindow.py b/modules/ui/SampleWindow.py index a93b3593f..15a7dfee2 100644 --- a/modules/ui/SampleWindow.py +++ b/modules/ui/SampleWindow.py @@ -18,6 +18,7 @@ from modules.util.enum.EMAMode import EMAMode from modules.util.enum.FileType import FileType from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.ModuleFilter import ModuleFilter from modules.util.time_util import get_string_timestamp from modules.util.ui import components from modules.util.ui.ui_utils import set_window_icon @@ -124,10 +125,15 @@ def __load_model(self) -> BaseModel: else: print("No backup found, loading without backup...") + quant_filters = [ + ModuleFilter(pattern, use_regex=self.initial_train_config.quantization_layer_filter_regex) + for pattern in self.initial_train_config.quantization_layer_filter.split(",") + ] model = model_loader.load( model_type=self.initial_train_config.model_type, model_names=model_names, weight_dtypes=self.initial_train_config.weight_dtypes(), + quant_filters=quant_filters, ) model.train_config = self.initial_train_config diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index 80f3b3f92..c4f925fc7 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -1056,6 +1056,11 @@ def default_values() -> 'TrainConfig': data.append(("layer_filter_preset", "full", str, False)) data.append(("layer_filter_regex", False, bool, False)) + #quantization layer filter + data.append(("quantization_layer_filter", "", str, False)) + data.append(("quantization_layer_filter_preset", "full", str, False)) + data.append(("quantization_layer_filter_regex", False, bool, False)) + # embedding data.append(("embedding_learning_rate", None, float, True)) data.append(("preserve_embedding_norm", False, bool, False)) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 3a3200e23..b69907baf 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -4,6 +4,7 @@ from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin from modules.util.enum.DataType import DataType +from modules.util.ModuleFilter import ModuleFilter import torch from torch import Tensor, nn @@ -75,10 +76,13 @@ def __replace_linear_layers_recursive( parent_module: nn.Module, convert_fn: Callable[[nn.Linear, bool], nn.Module], keep_in_fp32_modules: list[str] | None = None, + filters: list[ModuleFilter] | None = None, copy_parameters: bool = False, name_prefix: str = "", visited_modules: set[int] | None = None, ): + #both 'keep_in_fp32_modules' and 'filters' are layer filters: keep_in_fp32_modules is set by diffusers, 'filters' is set by the user. + #Apply both. 
'keep_in_fp32_modules' only looks at attr_name, 'filters' looks at the entire key at the leafs: if keep_in_fp32_modules is None: keep_in_fp32_modules = [] @@ -87,9 +91,13 @@ def __replace_linear_layers_recursive( visited_modules = set() visited_modules.add(id(parent_module)) + if isinstance(parent_module, (nn.ModuleList, nn.Sequential)): for i, module in enumerate(parent_module): if isinstance(module, nn.Linear): + if filters is not None and len(filters) > 0 and not any(f.matches(name_prefix) for f in filters): + continue + quant_linear = convert_fn(module, copy_parameters) parent_module[i] = quant_linear del module @@ -98,6 +106,7 @@ def __replace_linear_layers_recursive( parent_module=module, convert_fn=convert_fn, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, name_prefix=f"{name_prefix}[{i}]", visited_modules=visited_modules, @@ -109,6 +118,10 @@ def __replace_linear_layers_recursive( module = getattr(parent_module, attr_name) if isinstance(module, nn.Linear): + key_name = attr_name if name_prefix == "" else f"{name_prefix}.{attr_name}" + if filters is not None and len(filters) > 0 and not any(f.matches(key_name) for f in filters): + continue + quant_linear = convert_fn(module, copy_parameters) setattr(parent_module, attr_name, quant_linear) del module @@ -117,8 +130,9 @@ def __replace_linear_layers_recursive( parent_module=module, convert_fn=convert_fn, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, - name_prefix=f"{name_prefix}.{attr_name}", + name_prefix=attr_name if name_prefix == "" else f"{name_prefix}.{attr_name}", visited_modules=visited_modules, ) @@ -126,9 +140,10 @@ def __replace_linear_layers( parent_module: nn.Module, convert_fn: Callable[[nn.Linear, bool], nn.Module], keep_in_fp32_modules: list[str] | None = None, + filters: list[ModuleFilter] | None = None, copy_parameters: bool = False, ): - __replace_linear_layers_recursive(parent_module, convert_fn, keep_in_fp32_modules, copy_parameters) + __replace_linear_layers_recursive(parent_module, convert_fn, keep_in_fp32_modules, filters, copy_parameters) #ensure that all Linear layers were replaced #https://github.com/Nerogar/OneTrainer/issues/1050 @@ -136,17 +151,20 @@ def __replace_linear_layers( assert (not isinstance(module, nn.Linear) or isinstance(module, QuantizedLinearMixin) or any(s in name.split('.') for s in keep_in_fp32_modules) + or (filters is not None and len(filters) > 0 and not any(f.matches(name) for f in filters)) ), f"Linear layer {name} was not found in model for quantization" def replace_linear_with_nf4_layers( parent_module: nn.Module, keep_in_fp32_modules: list[str] | None = None, + filters: list[ModuleFilter] | None = None, copy_parameters: bool = False, ): __replace_linear_layers( parent_module=parent_module, convert_fn=__create_nf4_linear_layer, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, ) @@ -154,12 +172,14 @@ def replace_linear_with_nf4_layers( def replace_linear_with_int8_layers( parent_module: nn.Module, keep_in_fp32_modules: list[str] | None = None, + filters: list[ModuleFilter] | None = None, copy_parameters: bool = False, ): __replace_linear_layers( parent_module=parent_module, convert_fn=__create_int8_linear_layer, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, ) @@ -167,12 +187,14 @@ def replace_linear_with_int8_layers( def replace_linear_with_fp8_layers( parent_module: nn.Module, 
keep_in_fp32_modules: list[str] | None = None, + filters: list[ModuleFilter] | None = None, copy_parameters: bool = False, ): __replace_linear_layers( parent_module=parent_module, convert_fn=__create_fp8_linear_layer, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, ) From e5d031799528ef0c1d1463c94d2f7b1be7196487 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 2 Nov 2025 15:09:07 +0100 Subject: [PATCH 30/54] add blocks preset --- modules/modelSetup/BaseHiDreamSetup.py | 1 + modules/modelSetup/BaseHunyuanVideoSetup.py | 1 + modules/modelSetup/BasePixArtAlphaSetup.py | 1 + modules/modelSetup/BaseSanaSetup.py | 1 + modules/modelSetup/BaseStableDiffusion3Setup.py | 1 + 5 files changed, 5 insertions(+) diff --git a/modules/modelSetup/BaseHiDreamSetup.py b/modules/modelSetup/BaseHiDreamSetup.py index f98ecec43..ade011541 100644 --- a/modules/modelSetup/BaseHiDreamSetup.py +++ b/modules/modelSetup/BaseHiDreamSetup.py @@ -27,6 +27,7 @@ PRESETS = { "attn-mlp": ["attn1", "ff_i"], "attn-only": ["attn1"], + "blocks": ["stream_block"], "full": [], } diff --git a/modules/modelSetup/BaseHunyuanVideoSetup.py b/modules/modelSetup/BaseHunyuanVideoSetup.py index 70db6806a..437d262ec 100644 --- a/modules/modelSetup/BaseHunyuanVideoSetup.py +++ b/modules/modelSetup/BaseHunyuanVideoSetup.py @@ -28,6 +28,7 @@ PRESETS = { "attn-mlp": ["attn", "ff.net"], "attn-only": ["attn"], + "blocks": ["transformer_block"], "full": [], } diff --git a/modules/modelSetup/BasePixArtAlphaSetup.py b/modules/modelSetup/BasePixArtAlphaSetup.py index 754a1fd5c..dfb84c496 100644 --- a/modules/modelSetup/BasePixArtAlphaSetup.py +++ b/modules/modelSetup/BasePixArtAlphaSetup.py @@ -27,6 +27,7 @@ PRESETS = { "attn-mlp": ["attn1", "attn2", "ff.net"], "attn-only": ["attn1", "attn2"], + "blocks": ["transformer_block"], "full": [], } diff --git a/modules/modelSetup/BaseSanaSetup.py b/modules/modelSetup/BaseSanaSetup.py index ee58332d3..a5cd50fa5 100644 --- a/modules/modelSetup/BaseSanaSetup.py +++ b/modules/modelSetup/BaseSanaSetup.py @@ -27,6 +27,7 @@ PRESETS = { "attn-mlp": ["attn1", "attn2", "ff."], "attn-only": ["attn1", "attn2"], + "blocks": ["transformer_block"], "full": [], } diff --git a/modules/modelSetup/BaseStableDiffusion3Setup.py b/modules/modelSetup/BaseStableDiffusion3Setup.py index 4d20536cf..48d08faf8 100644 --- a/modules/modelSetup/BaseStableDiffusion3Setup.py +++ b/modules/modelSetup/BaseStableDiffusion3Setup.py @@ -27,6 +27,7 @@ PRESETS = { "attn-only": ["attn"], + "blocks": ["transformer_block"], "full": [], } From 0242d16a683042e0fc1f3e1012957408f1489f42 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 2 Nov 2025 16:07:02 +0100 Subject: [PATCH 31/54] quantization filter in presets --- training_presets/#chroma LoRA 16GB.json | 4 +++- training_presets/#chroma LoRA 24GB.json | 4 +++- training_presets/#chroma LoRA 8GB.json | 4 +++- training_presets/#qwen LoRA 16GB.json | 4 +++- training_presets/#qwen LoRA 24GB.json | 4 +++- 5 files changed, 15 insertions(+), 5 deletions(-) diff --git a/training_presets/#chroma LoRA 16GB.json b/training_presets/#chroma LoRA 16GB.json index a99eecca4..0cbe71627 100644 --- a/training_presets/#chroma LoRA 16GB.json +++ b/training_presets/#chroma LoRA 16GB.json @@ -22,5 +22,7 @@ "timestep_distribution": "INVERTED_PARABOLA", "noising_weight": 7.7, "layer_filter": "attn,ff.net", - "layer_filter_preset": "attn-mlp" + "layer_filter_preset": "attn-mlp", + "quantization_layer_filter": "transformer_block", + "quantization_layer_filter_preset": "blocks" } diff 
--git a/training_presets/#chroma LoRA 24GB.json b/training_presets/#chroma LoRA 24GB.json index 5877009c6..76868401b 100644 --- a/training_presets/#chroma LoRA 24GB.json +++ b/training_presets/#chroma LoRA 24GB.json @@ -22,5 +22,7 @@ "timestep_distribution": "INVERTED_PARABOLA", "noising_weight": 7.7, "layer_filter": "attn,ff.net", - "layer_filter_preset": "attn-mlp" + "layer_filter_preset": "attn-mlp", + "quantization_layer_filter": "transformer_block", + "quantization_layer_filter_preset": "blocks" } diff --git a/training_presets/#chroma LoRA 8GB.json b/training_presets/#chroma LoRA 8GB.json index 6fa19670c..e8a5e556b 100644 --- a/training_presets/#chroma LoRA 8GB.json +++ b/training_presets/#chroma LoRA 8GB.json @@ -25,5 +25,7 @@ "timestep_distribution": "INVERTED_PARABOLA", "noising_weight": 7.7, "layer_filter": "attn,ff.net", - "layer_filter_preset": "attn-mlp" + "layer_filter_preset": "attn-mlp", + "quantization_layer_filter": "transformer_block", + "quantization_layer_filter_preset": "blocks" } diff --git a/training_presets/#qwen LoRA 16GB.json b/training_presets/#qwen LoRA 16GB.json index a101c788c..7d877b9fc 100644 --- a/training_presets/#qwen LoRA 16GB.json +++ b/training_presets/#qwen LoRA 16GB.json @@ -24,5 +24,7 @@ "output_dtype": "BFLOAT_16", "timestep_distribution": "LOGIT_NORMAL", "layer_filter": "attn,img_mlp,txt_mlp", - "layer_filter_preset": "attn-mlp" + "layer_filter_preset": "attn-mlp", + "quantization_layer_filter": "transformer_block", + "quantization_layer_filter_preset": "blocks" } diff --git a/training_presets/#qwen LoRA 24GB.json b/training_presets/#qwen LoRA 24GB.json index bd03b0cc2..66bd7324d 100644 --- a/training_presets/#qwen LoRA 24GB.json +++ b/training_presets/#qwen LoRA 24GB.json @@ -24,5 +24,7 @@ "output_dtype": "BFLOAT_16", "timestep_distribution": "LOGIT_NORMAL", "layer_filter": "attn,img_mlp,txt_mlp", - "layer_filter_preset": "attn-mlp" + "layer_filter_preset": "attn-mlp", + "quantization_layer_filter": "transformer_block", + "quantization_layer_filter_preset": "blocks" } From 9e897b8160fc58f44263ec18cfcc5a149f686aa0 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 2 Nov 2025 17:54:55 +0100 Subject: [PATCH 32/54] #1054 --- modules/trainer/GenericTrainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 7dfa76398..94880e0fd 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -157,7 +157,7 @@ def start(self): def __save_config_to_workspace(self): path = path_util.canonical_join(self.config.workspace_dir, "config") os.makedirs(Path(path).absolute(), exist_ok=True) - path = path_util.canonical_join(path, f"{get_string_timestamp()}.json") + path = path_util.canonical_join(path, f"{self.config.save_filename_prefix}{get_string_timestamp()}.json") with open(path, "w") as f: json.dump(self.config.to_pack_dict(secrets=False), f, indent=4) From 3448801d517223b3e66cd4e89721bd2873e84543 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 2 Nov 2025 19:50:28 +0100 Subject: [PATCH 33/54] bugfix --- modules/util/ui/components.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/util/ui/components.py b/modules/util/ui/components.py index 1c541a606..4157551ab 100644 --- a/modules/util/ui/components.py +++ b/modules/util/ui/components.py @@ -357,8 +357,8 @@ def preset_set_layer_choice(selected: str): layer_entry.configure(state="disabled", fg_color=disabled_color, text_color=disabled_text_color) 
layer_entry.cget('textvariable').set(",".join(patterns)) - ui_state.get_var("layer_filter").set(",".join(patterns)) - ui_state.get_var("layer_filter_regex").set(preset_uses_regex) + ui_state.get_var(entry_var_name).set(",".join(patterns)) + ui_state.get_var(regex_var_name).set(preset_uses_regex) regex_label.grid_remove() regex_switch.grid_remove() From 5bc6c5abeb36d185c3d0391c1ea2b48bd253282c Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 3 Nov 2025 22:57:57 +0100 Subject: [PATCH 34/54] smaller eps, because gradients for some models are close to 1e-12 --- modules/module/quantized/LinearW8A8.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 49ed6607d..c51489ee2 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -13,7 +13,7 @@ def quantize_int8(x: Tensor, scale: float | Tensor) -> Tensor: def quantize_int8_tensorwise_get_scale(x: Tensor) -> float: abs_max = x.abs().max() - scale = (abs_max.float() / 127.0).clamp(min=1e-12) + scale = (abs_max.float() / 127.0).clamp(min=1e-30) return scale def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: @@ -23,7 +23,7 @@ def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: def quantize_int8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: abs_max = x.abs().amax(dim=dim, keepdim=True) - scale = (abs_max.float() / 127.0).clamp(min=1e-12) + scale = (abs_max.float() / 127.0).clamp(min=1e-30) return scale def quantize_int8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: @@ -37,12 +37,12 @@ def quantize_fp8(x: Tensor, scale: float | Tensor) -> Tensor: def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: abs_max = x.abs().max() - scale = (abs_max.float() / 448.0).clamp(min=1e-12) + scale = (abs_max.float() / 448.0).clamp(min=1e-30) return scale def quantize_fp8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: abs_max = x.abs().amax(dim=dim, keepdim=True) - scale = (abs_max.float() / 448.0).clamp(min=1e-12) + scale = (abs_max.float() / 448.0).clamp(min=1e-30) return scale def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: From fb1e8a8208b8715980e4cdcca476a9eddf861ce7 Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 4 Nov 2025 19:03:33 +0100 Subject: [PATCH 35/54] compile benchmarks --- modules/module/quantized/LinearW8A8.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index c51489ee2..1b92507b2 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -178,7 +178,9 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor: assert y.dtype == self._compute_dtype return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) -def run_benchmark(fn, desc, steps=10000, warmup=500): +def run_benchmark(fn, desc, steps=10000, warmup=500, compile=False): + if compile: + fn = torch.compile(fn, fullgraph=True) from tqdm import tqdm for _ in range(warmup): fn() @@ -205,8 +207,8 @@ def torch_backward(a, b): run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8") run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward int8") - run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int") - run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int") + run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int", 
compile=True) + run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int", compile=True) @torch.no_grad() @@ -225,8 +227,8 @@ def torch_backward(a, b): torch._scaled_mm(a, b.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()) run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward fp8") run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward fp8") - run_benchmark(lambda: fp8_forward_tokenwise(x, w_8, w_scale), "torch forward fp8") - run_benchmark(lambda: fp8_backward_axiswise(y, w_8, w_scale), "triton backward fp8") + run_benchmark(lambda: fp8_forward_tokenwise(x, w_8, w_scale), "torch forward fp8", compile=True) + run_benchmark(lambda: fp8_backward_axiswise(y, w_8, w_scale), "triton backward fp8", compile=True) if __name__ == "__main__": From 4ca84db4a1635de7413e679c0c1b3429499f3209 Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 4 Nov 2025 19:48:05 +0100 Subject: [PATCH 36/54] remove cast --- modules/module/quantized/LinearW8A8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 1b92507b2..9d2056630 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -107,7 +107,7 @@ class LinearFp8Function(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | None) -> Tensor: ctx.save_for_backward(weight, weight_scale) - return fp8_forward_tokenwise(x.bfloat16(), weight, weight_scale, bias).bfloat16() + return fp8_forward_tokenwise(x, weight, weight_scale, bias) @staticmethod def backward(ctx, x: Tensor): From 41e44a27d43f1c2c6e8c798d16b9d4365056e427 Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 4 Nov 2025 19:55:39 +0100 Subject: [PATCH 37/54] detach dequantized weights --- modules/module/quantized/LinearGGUFA8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index 1898c2915..a479f8d0b 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -83,7 +83,7 @@ def __init__(self, dtype: torch.dtype, *args, **kwargs): def forward(self, x_orig: torch.Tensor) -> torch.Tensor: assert not self.weight.requires_grad x = x_orig.to(self.compute_dtype).reshape(-1, x_orig.shape[-1]) - w = dequantize_gguf_tensor(self.weight) + w = dequantize_gguf_tensor(self.weight).detach() if x.shape[0] > 16 and self.weight.quant_type not in UNQUANTIZED_TYPES: if self._dtype == torch.int8: From b3f69aef6d4b6749457f2dfedd2e34980475250a Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 7 Nov 2025 11:55:04 +0100 Subject: [PATCH 38/54] name changes --- modules/module/quantized/LinearGGUFA8.py | 32 +++++++++++------------ modules/module/quantized/LinearW8A8.py | 33 +++++++++++------------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index a479f8d0b..24a5a73d9 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -14,7 +14,7 @@ UNQUANTIZED_TYPES = [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16] -def int8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: +def int8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: x_8, x_scale = 
quantize_int8_axiswise(x, dim=-1) w_8, w_scale = quantize_int8_axiswise(weight, dim=-1) res = torch._int_mm(x_8, w_8.T) @@ -23,7 +23,7 @@ def int8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> res_scaled.add_(bias.to(x.dtype)) return res_scaled -def fp8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: +def fp8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) w_8, w_scale = quantize_fp8_axiswise(weight, dim=-1) one = torch.ones(1, device=x.device) @@ -33,17 +33,17 @@ def fp8_forward_both_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> T res_scaled.add_(bias.to(x.dtype)) return res_scaled -def int8_backward_both_axiswise(x: Tensor, weight: Tensor) -> Tensor: - x_8, x_scale = quantize_int8_axiswise(x, dim=-1) +def int8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor: + output_8, output_scale = quantize_int8_axiswise(output, dim=-1) w_8, w_scale = quantize_int8_axiswise(weight, dim=0) - mm_res = triton_mm_8bit(x_8, w_8) - return mm_res.to(x.dtype).mul_(w_scale).mul_(x_scale) + mm_res = triton_mm_8bit(output_8, w_8) + return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) -def fp8_backward_both_axiswise(x: Tensor, weight: Tensor) -> Tensor: - x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) +def fp8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor: + output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) w_8, w_scale = quantize_fp8_axiswise(weight, dim=0) - mm_res = triton_mm_8bit(x_8, w_8) - return mm_res.to(x.dtype).mul_(w_scale).mul_(x_scale) + mm_res = triton_mm_8bit(output_8, w_8) + return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) class LinearGGUFIntA8RequantFunction(torch.autograd.Function): @staticmethod @@ -51,27 +51,27 @@ def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ctx.save_for_backward(weight) #axiswise performs better than tensorwise in tests, even though #it requires another requant during backward - but requant is cheap - return int8_forward_both_axiswise(x, weight, bias) + return int8_forward_axiswise(x, weight, bias) @staticmethod - def backward(ctx, x: Tensor): + def backward(ctx, output: Tensor): if ctx.needs_input_grad != (True, False, False): raise NotImplementedError("GGUF cannot be used for full finetuning") weight, = ctx.saved_tensors - return int8_backward_both_axiswise(x, weight), None, None + return int8_backward_axiswise(output, weight), None, None class LinearGGUFFpA8RequantFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ctx.save_for_backward(weight) - return fp8_forward_both_axiswise(x, weight, bias) + return fp8_forward_axiswise(x, weight, bias) @staticmethod - def backward(ctx, x: Tensor): + def backward(ctx, output: Tensor): if ctx.needs_input_grad != (True, False, False): raise NotImplementedError("GGUF cannot be used for full finetuning") weight, = ctx.saved_tensors - return fp8_backward_both_axiswise(x, weight), None, None + return fp8_backward_axiswise(output, weight), None, None class LinearGGUFA8(GGUFLinear): def __init__(self, dtype: torch.dtype, *args, **kwargs): diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 9d2056630..f6a5a7964 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -55,7 +55,7 @@ def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, 
Tensor]: q = quantize_fp8(x, scale) return q, scale -def unquantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> Tensor: +def dequantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> Tensor: return q.to(compute_dtype) * scale.to(compute_dtype) def int8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: @@ -71,22 +71,19 @@ def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: one = torch.ones(1, device=x.device) res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm - if bias is not None: res_scaled.add_(bias.to(x.dtype)) return res_scaled +def int8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor: + output_8, output_scale = quantize_int8_axiswise(output, dim=-1) + mm_res = triton_mm_8bit(output_8, weight) + return mm_res.to(output.dtype).mul_(weight_scale * output_scale) -def int8_backward_axiswise(x: Tensor, weight: Tensor, weight_scale: float) -> Tensor: - x_8, x_scale = quantize_int8_axiswise(x, dim=-1) - mm_res = triton_mm_8bit(x_8, weight) - return mm_res.to(x.dtype).mul_(weight_scale * x_scale) - -def fp8_backward_axiswise(x: Tensor, weight: Tensor, weight_scale: float) -> Tensor: - x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) - mm_res = triton_mm_8bit(x_8, weight) - return mm_res.to(x.dtype).mul_(weight_scale * x_scale) - +def fp8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor: + output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) + mm_res = triton_mm_8bit(output_8, weight) + return mm_res.to(output.dtype).mul_(weight_scale * output_scale) class LinearInt8Function(torch.autograd.Function): @@ -96,12 +93,12 @@ def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | return int8_forward_tokenwise(x, weight, weight_scale, bias) @staticmethod - def backward(ctx, x: Tensor): + def backward(ctx, output: Tensor): if ctx.needs_input_grad != (True, False, False, False): raise NotImplementedError("Int A8W8 cannot be used for full finetuning") weight, weight_scale = ctx.saved_tensors - return int8_backward_axiswise(x, weight, weight_scale), None, None, None + return int8_backward_axiswise(output, weight, weight_scale), None, None, None class LinearFp8Function(torch.autograd.Function): @staticmethod @@ -110,12 +107,12 @@ def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | return fp8_forward_tokenwise(x, weight, weight_scale, bias) @staticmethod - def backward(ctx, x: Tensor): + def backward(ctx, output: Tensor): if ctx.needs_input_grad != (True, False, False, False): raise NotImplementedError("Float A8W8 cannot be used for full finetuning") weight, weight_scale = ctx.saved_tensors - return fp8_backward_axiswise(x, weight, weight_scale), None, None, None + return fp8_backward_axiswise(output, weight, weight_scale), None, None, None class LinearW8A8( nn.Linear, @@ -136,7 +133,7 @@ def original_weight_shape(self) -> tuple[int, ...]: return self.weight.shape def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: - return unquantize(self.weight.detach(), self.scale, self._compute_dtype).to(dtype) + return dequantize(self.weight.detach(), self.scale, self._compute_dtype).to(dtype) @torch.no_grad() def quantize(self, device: torch.device | None = None, **kwargs): @@ -172,7 +169,7 @@ def forward(self, x_orig: 
torch.Tensor) -> torch.Tensor: else: y = LinearFp8Function.apply(x, self.weight, self.scale, self.bias) else: - w = unquantize(self.weight, self.scale, compute_dtype=self._compute_dtype) + w = dequantize(self.weight, self.scale, compute_dtype=self._compute_dtype) y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) assert y.dtype == self._compute_dtype From f9c12a8f74e8c076bee2e2b98b40150bf8fe7635 Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 7 Nov 2025 12:03:44 +0100 Subject: [PATCH 39/54] move code --- modules/module/quantized/LinearGGUFA8.py | 2 +- modules/module/quantized/LinearW8A8.py | 58 +++--------------------- modules/util/quantization_util.py | 51 +++++++++++++++++++++ 3 files changed, 59 insertions(+), 52 deletions(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index 24a5a73d9..90c3db79f 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -1,4 +1,4 @@ -from modules.module.quantized.LinearW8A8 import ( +from modules.util.quantization_util import ( quantize_fp8_axiswise, quantize_int8_axiswise, ) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index f6a5a7964..d9f056bb6 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -1,63 +1,19 @@ from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin +from modules.util.quantization_util import ( + dequantize, + quantize_fp8_axiswise, + quantize_fp8_tensorwise, + quantize_int8_axiswise, + quantize_int8_tensorwise, +) from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit import torch from torch import Tensor, nn -def quantize_int8(x: Tensor, scale: float | Tensor) -> Tensor: - q = x.float().mul(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) - return q - -def quantize_int8_tensorwise_get_scale(x: Tensor) -> float: - abs_max = x.abs().max() - scale = (abs_max.float() / 127.0).clamp(min=1e-30) - return scale - -def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: - scale = quantize_int8_tensorwise_get_scale(x) - q = quantize_int8(x, scale) - return q, scale - -def quantize_int8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: - abs_max = x.abs().amax(dim=dim, keepdim=True) - scale = (abs_max.float() / 127.0).clamp(min=1e-30) - return scale - -def quantize_int8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: - scale = quantize_int8_axiswise_get_scale(x, dim) - q = quantize_int8(x, scale) - return q, scale - -def quantize_fp8(x: Tensor, scale: float | Tensor) -> Tensor: - q = x.float().mul(1.0 / scale).clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) - return q - -def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: - abs_max = x.abs().max() - scale = (abs_max.float() / 448.0).clamp(min=1e-30) - return scale - -def quantize_fp8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: - abs_max = x.abs().amax(dim=dim, keepdim=True) - scale = (abs_max.float() / 448.0).clamp(min=1e-30) - return scale - -def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: - scale = quantize_fp8_tensorwise_get_scale(x) - q = quantize_fp8(x, scale) - return q, scale - -def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: - scale = quantize_fp8_axiswise_get_scale(x, dim) - q = quantize_fp8(x, scale) - return q, scale - -def dequantize(q: Tensor, scale: float | Tensor, 
compute_dtype: torch.dtype) -> Tensor: - return q.to(compute_dtype) * scale.to(compute_dtype) - def int8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: x_8, x_scale = quantize_int8_axiswise(x, dim=-1) res = torch._int_mm(x_8, weight.T) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index c2c324f1e..3b7b85baa 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -242,3 +242,54 @@ def offload_quantized( new_tensor = allocator(tensor) new_tensor.copy_(tensor.data, non_blocking=non_blocking) tensor.data = new_tensor + +def quantize_int8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) + return q + +def quantize_int8_tensorwise_get_scale(x: Tensor) -> float: + abs_max = x.abs().max() + scale = (abs_max.float() / 127.0).clamp(min=1e-30) + return scale + +def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_int8_tensorwise_get_scale(x) + q = quantize_int8(x, scale) + return q, scale + +def quantize_int8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 127.0).clamp(min=1e-30) + return scale + +def quantize_int8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_int8_axiswise_get_scale(x, dim) + q = quantize_int8(x, scale) + return q, scale + +def quantize_fp8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) + return q + +def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: + abs_max = x.abs().max() + scale = (abs_max.float() / 448.0).clamp(min=1e-30) + return scale + +def quantize_fp8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 448.0).clamp(min=1e-30) + return scale + +def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_fp8_tensorwise_get_scale(x) + q = quantize_fp8(x, scale) + return q, scale + +def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_fp8_axiswise_get_scale(x, dim) + q = quantize_fp8(x, scale) + return q, scale + +def dequantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> Tensor: + return q.to(compute_dtype) * scale.to(compute_dtype) From 5db4161d85308a00e8b34ac4306104ea4f0b0d4c Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 7 Nov 2025 12:22:47 +0100 Subject: [PATCH 40/54] fix circular dependency --- modules/util/quantization_util.py | 113 +++++++++++++++--------------- 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 3b7b85baa..0a43d7c94 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -2,10 +2,6 @@ from collections.abc import Callable from functools import partial -from modules.module.quantized.LinearFp8 import LinearFp8 -from modules.module.quantized.LinearGGUFA8 import LinearGGUFA8 -from modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear -from modules.module.quantized.LinearW8A8 import LinearW8A8 from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin from modules.util.config.TrainConfig import TrainConfig @@ -27,6 +23,64 @@ bnb = None 
LinearNf4 = None +def quantize_int8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) + return q + +def quantize_int8_tensorwise_get_scale(x: Tensor) -> float: + abs_max = x.abs().max() + scale = (abs_max.float() / 127.0).clamp(min=1e-30) + return scale + +def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_int8_tensorwise_get_scale(x) + q = quantize_int8(x, scale) + return q, scale + +def quantize_int8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 127.0).clamp(min=1e-30) + return scale + +def quantize_int8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_int8_axiswise_get_scale(x, dim) + q = quantize_int8(x, scale) + return q, scale + +def quantize_fp8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) + return q + +def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: + abs_max = x.abs().max() + scale = (abs_max.float() / 448.0).clamp(min=1e-30) + return scale + +def quantize_fp8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 448.0).clamp(min=1e-30) + return scale + +def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_fp8_tensorwise_get_scale(x) + q = quantize_fp8(x, scale) + return q, scale + +def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_fp8_axiswise_get_scale(x, dim) + q = quantize_fp8(x, scale) + return q, scale + +def dequantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> Tensor: + return q.to(compute_dtype) * scale.to(compute_dtype) + + +from modules.module.quantized.LinearFp8 import LinearFp8 +from modules.module.quantized.LinearGGUFA8 import LinearGGUFA8 +from modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear +from modules.module.quantized.LinearW8A8 import LinearW8A8 + + def __create_linear_layer(construct_fn, module: nn.Linear, copy_parameters: bool) -> nn.Module: bias = module.bias is not None quant_linear = construct_fn( @@ -242,54 +296,3 @@ def offload_quantized( new_tensor = allocator(tensor) new_tensor.copy_(tensor.data, non_blocking=non_blocking) tensor.data = new_tensor - -def quantize_int8(x: Tensor, scale: float | Tensor) -> Tensor: - q = x.float().mul(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) - return q - -def quantize_int8_tensorwise_get_scale(x: Tensor) -> float: - abs_max = x.abs().max() - scale = (abs_max.float() / 127.0).clamp(min=1e-30) - return scale - -def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: - scale = quantize_int8_tensorwise_get_scale(x) - q = quantize_int8(x, scale) - return q, scale - -def quantize_int8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: - abs_max = x.abs().amax(dim=dim, keepdim=True) - scale = (abs_max.float() / 127.0).clamp(min=1e-30) - return scale - -def quantize_int8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: - scale = quantize_int8_axiswise_get_scale(x, dim) - q = quantize_int8(x, scale) - return q, scale - -def quantize_fp8(x: Tensor, scale: float | Tensor) -> Tensor: - q = x.float().mul(1.0 / scale).clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) - return q - -def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: - abs_max = x.abs().max() - scale = (abs_max.float() / 448.0).clamp(min=1e-30) - 
return scale - -def quantize_fp8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: - abs_max = x.abs().amax(dim=dim, keepdim=True) - scale = (abs_max.float() / 448.0).clamp(min=1e-30) - return scale - -def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: - scale = quantize_fp8_tensorwise_get_scale(x) - q = quantize_fp8(x, scale) - return q, scale - -def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: - scale = quantize_fp8_axiswise_get_scale(x, dim) - q = quantize_fp8(x, scale) - return q, scale - -def dequantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> Tensor: - return q.to(compute_dtype) * scale.to(compute_dtype) From cd3f971b4f96631a1ddbe1647c2cbd49de2e60db Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 7 Nov 2025 14:33:53 +0100 Subject: [PATCH 41/54] ensure contiguous grad output --- modules/module/quantized/LinearGGUFA8.py | 4 ++-- modules/module/quantized/LinearW8A8.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index 90c3db79f..c8300e855 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -36,13 +36,13 @@ def fp8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor def int8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor: output_8, output_scale = quantize_int8_axiswise(output, dim=-1) w_8, w_scale = quantize_int8_axiswise(weight, dim=0) - mm_res = triton_mm_8bit(output_8, w_8) + mm_res = triton_mm_8bit(output_8.contiguous(), w_8) return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) def fp8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor: output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) w_8, w_scale = quantize_fp8_axiswise(weight, dim=0) - mm_res = triton_mm_8bit(output_8, w_8) + mm_res = triton_mm_8bit(output_8.contiguous(), w_8) return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) class LinearGGUFIntA8RequantFunction(torch.autograd.Function): diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index d9f056bb6..3bda82008 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -33,12 +33,13 @@ def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: def int8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor: output_8, output_scale = quantize_int8_axiswise(output, dim=-1) - mm_res = triton_mm_8bit(output_8, weight) + #almost always, grad outputs are already contiguous and this is a no-op. 
But there are some grad outputs from SDXL that are non-contiguous: + mm_res = triton_mm_8bit(output_8.contiguous(), weight) return mm_res.to(output.dtype).mul_(weight_scale * output_scale) def fp8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor: output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) - mm_res = triton_mm_8bit(output_8, weight) + mm_res = triton_mm_8bit(output_8.contiguous(), weight) return mm_res.to(output.dtype).mul_(weight_scale * output_scale) From 41a05b94d79e0cec78ab3595028951fe632cf5e0 Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 7 Nov 2025 14:39:37 +0100 Subject: [PATCH 42/54] W16A8 --- modules/module/quantized/LinearA8.py | 149 +++++++++++++++++++++++ modules/module/quantized/LinearGGUFA8.py | 45 ++----- modules/module/quantized/LinearW8A8.py | 1 - modules/ui/ModelTab.py | 4 + modules/util/enum/DataType.py | 4 + modules/util/quantization_util.py | 15 ++- 6 files changed, 174 insertions(+), 44 deletions(-) create mode 100644 modules/module/quantized/LinearA8.py diff --git a/modules/module/quantized/LinearA8.py b/modules/module/quantized/LinearA8.py new file mode 100644 index 000000000..745c49a9f --- /dev/null +++ b/modules/module/quantized/LinearA8.py @@ -0,0 +1,149 @@ +from modules.util.quantization_util import ( + quantize_fp8_axiswise, + quantize_int8_axiswise, +) +from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit + +import torch +from torch import Tensor, nn + + +def int8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_int8_axiswise(x, dim=-1) + w_8, w_scale = quantize_int8_axiswise(weight, dim=-1) + res = torch._int_mm(x_8, w_8.T) + res_scaled = res.to(x.dtype).mul_(w_scale.T).mul_(x_scale) + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def fp8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) + w_8, w_scale = quantize_fp8_axiswise(weight, dim=-1) + one = torch.ones(1, device=x.device) + res = torch._scaled_mm(x_8, w_8.T, scale_a=one, scale_b=one, out_dtype=x.dtype) + res_scaled = res.mul_(w_scale.T).mul_(x_scale) + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def int8_backward_act_axiswise(output: Tensor, weight: Tensor) -> Tensor: + output_8, output_scale = quantize_int8_axiswise(output, dim=-1) + w_8, w_scale = quantize_int8_axiswise(weight, dim=0) + #almost always, grad outputs are already contiguous and this is a no-op. 
But there are some grad outputs from SDXL that are non-contiguous: + output_8 = output_8.contiguous() + mm_res = triton_mm_8bit(output_8, w_8) + return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) + +def fp8_backward_act_axiswise(output: Tensor, weight: Tensor) -> Tensor: + output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) + w_8, w_scale = quantize_fp8_axiswise(weight, dim=0) + mm_res = triton_mm_8bit(output_8.contiguous(), w_8) + return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) + +def int8_backward_weight_axiswise(output: Tensor, x: Tensor) -> Tensor: + output_8, output_scale = quantize_int8_axiswise(output, dim=0) + x_8, x_scale = quantize_int8_axiswise(x, dim=0) + #TODO could be more efficient using a kernel that accepts a non-contiguous lhs matrix + mm_res = triton_mm_8bit(output_8.T.contiguous(), x_8) + return mm_res.to(x.dtype).mul_(output_scale.T).mul_(x_scale) + +def fp8_backward_weight_axiswise(output: Tensor, x: Tensor) -> Tensor: + output_8, output_scale = quantize_fp8_axiswise(output, dim=0) + x_8, x_scale = quantize_fp8_axiswise(x, dim=0) + mm_res = triton_mm_8bit(output_8.T.contiguous(), x_8) + return mm_res.to(x.dtype).mul_(output_scale.T).mul_(x_scale) + +class LinearInt8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(x, weight) + #axiswise performs better than tensorwise in tests, even though + #it requires another quant during backward - but quant is cheap + return int8_forward_axiswise(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output: Tensor): + x, weight = ctx.saved_tensors + + grad_x, grad_weight, grad_bias = None, None, None + if ctx.needs_input_grad[0]: + # grad_output @ weight.T + grad_x = int8_backward_act_axiswise(grad_output, weight) + if ctx.needs_input_grad[1]: + # grad_output.T @ x + grad_weight = int8_backward_weight_axiswise(grad_output, x) + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_x, grad_weight, grad_bias + +class LinearFp8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(x, weight) + return fp8_forward_axiswise(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output: Tensor): + x, weight = ctx.saved_tensors + + grad_x, grad_weight, grad_bias = None, None, None + if ctx.needs_input_grad[0]: + # grad_output @ weight.T + grad_x = fp8_backward_act_axiswise(grad_output, weight) + if ctx.needs_input_grad[1]: + # grad_output.T @ x + grad_weight = fp8_backward_weight_axiswise(grad_output, x) + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_x, grad_weight, grad_bias + + +class LinearA8(nn.Linear): + def __init__(self, dtype, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert dtype in [torch.int8, torch.float8_e4m3fn] + self._dtype = dtype + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + x = x_orig.to(self.weight.dtype).reshape(-1, x_orig.shape[-1]) + if x.shape[0] > 16: + if self._dtype == torch.int8: + y = LinearInt8Function.apply(x, self.weight, self.bias) + else: + y = LinearFp8Function.apply(x, self.weight, self.bias) + return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) + else: + return super().forward(x_orig) + + + +def run_benchmark(fn, desc, steps=10000, warmup=500, compile=False): + if compile: + fn = torch.compile(fn, fullgraph=True) + from tqdm import tqdm + for _ in range(warmup): + fn() 
+ torch.cuda.synchronize() + for _ in tqdm(range(steps), desc=desc): + fn() + torch.cuda.synchronize() + + +@torch.no_grad() +def benchmark(m, k, n, device = 'cuda'): + output = torch.randn(m, n, device=device, dtype=torch.bfloat16) + x = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + run_benchmark(lambda: int8_forward_axiswise(x, weight), "forward int8", compile=True) + run_benchmark(lambda: int8_backward_weight_axiswise(output, x), "backward weight int8", compile=True) + run_benchmark(lambda: fp8_forward_axiswise(x, weight), "forward fp8", compile=True) + run_benchmark(lambda: fp8_backward_weight_axiswise(output, x), "backward weight fp8", compile=True) + + +if __name__ == "__main__": + benchmark(2 * 1024 + 50, 3072, 3072 + 16) + #benchmark_fp8(2080, 3072) diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py index c8300e855..1a85eeb3f 100644 --- a/modules/module/quantized/LinearGGUFA8.py +++ b/modules/module/quantized/LinearGGUFA8.py @@ -1,8 +1,9 @@ -from modules.util.quantization_util import ( - quantize_fp8_axiswise, - quantize_int8_axiswise, +from modules.module.quantized.LinearA8 import ( + fp8_backward_act_axiswise, + fp8_forward_axiswise, + int8_backward_act_axiswise, + int8_forward_axiswise, ) -from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit import torch from torch import Tensor @@ -13,38 +14,6 @@ UNQUANTIZED_TYPES = [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16] - -def int8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: - x_8, x_scale = quantize_int8_axiswise(x, dim=-1) - w_8, w_scale = quantize_int8_axiswise(weight, dim=-1) - res = torch._int_mm(x_8, w_8.T) - res_scaled = res.to(x.dtype).mul_(w_scale.T).mul_(x_scale) - if bias is not None: - res_scaled.add_(bias.to(x.dtype)) - return res_scaled - -def fp8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: - x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) - w_8, w_scale = quantize_fp8_axiswise(weight, dim=-1) - one = torch.ones(1, device=x.device) - res = torch._scaled_mm(x_8, w_8.T, scale_a=one, scale_b=one, out_dtype=x.dtype) - res_scaled = res.mul_(w_scale.T).mul_(x_scale) #much faster than scaled by _scaled_mm - if bias is not None: - res_scaled.add_(bias.to(x.dtype)) - return res_scaled - -def int8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor: - output_8, output_scale = quantize_int8_axiswise(output, dim=-1) - w_8, w_scale = quantize_int8_axiswise(weight, dim=0) - mm_res = triton_mm_8bit(output_8.contiguous(), w_8) - return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) - -def fp8_backward_axiswise(output: Tensor, weight: Tensor) -> Tensor: - output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) - w_8, w_scale = quantize_fp8_axiswise(weight, dim=0) - mm_res = triton_mm_8bit(output_8.contiguous(), w_8) - return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) - class LinearGGUFIntA8RequantFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: @@ -58,7 +27,7 @@ def backward(ctx, output: Tensor): if ctx.needs_input_grad != (True, False, False): raise NotImplementedError("GGUF cannot be used for full finetuning") weight, = ctx.saved_tensors - return int8_backward_axiswise(output, weight), None, None + return int8_backward_act_axiswise(output, weight), None, None class 
LinearGGUFFpA8RequantFunction(torch.autograd.Function): @staticmethod @@ -71,7 +40,7 @@ def backward(ctx, output: Tensor): if ctx.needs_input_grad != (True, False, False): raise NotImplementedError("GGUF cannot be used for full finetuning") weight, = ctx.saved_tensors - return fp8_backward_axiswise(output, weight), None, None + return fp8_backward_act_axiswise(output, weight), None, None class LinearGGUFA8(GGUFLinear): def __init__(self, dtype: torch.dtype, *args, **kwargs): diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index 3bda82008..c58f6e34e 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -33,7 +33,6 @@ def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: def int8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor: output_8, output_scale = quantize_int8_axiswise(output, dim=-1) - #almost always, grad outputs are already contiguous and this is a no-op. But there are some grad outputs from SDXL that are non-contiguous: mm_res = triton_mm_8bit(output_8.contiguous(), weight) return mm_res.to(output.dtype).mul_(weight_scale * output_scale) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 347bfb328..592ab028d 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -370,6 +370,10 @@ def __create_dtype_options(self, include_none: bool=True, include_gguf: bool=Fal ("float W8A8", DataType.FLOAT_W8A8), ("int W8A8", DataType.INT_W8A8), # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 + ("bfloat16 A8 int", DataType.BFLOAT_16_A8_INT), + ("bfloat16 A8 float", DataType.BFLOAT_16_A8_FLOAT), + ("float16 A8 int", DataType.FLOAT_16_A8_INT), + ("float16 A8 float", DataType.FLOAT_16_A8_FLOAT), ("nfloat4", DataType.NFLOAT_4), ] diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py index 25b314d8d..f1b360118 100644 --- a/modules/util/enum/DataType.py +++ b/modules/util/enum/DataType.py @@ -21,6 +21,10 @@ class DataType(Enum): GGUF = 'GGUF' GGUF_A8_FLOAT = 'GGUF_A8_FLOAT' GGUF_A8_INT = 'GGUF_A8_INT' + FLOAT_16_A8_FLOAT = 'FLOAT_16 A8_FLOAT' + FLOAT_16_A8_INT = 'FLOAT_16 A8_INT' + BFLOAT_16_A8_FLOAT = 'BFLOAT_16_A8_FLOAT' + BFLOAT_16_A8_INT = 'BFLOAT_16_A8_INT' def __str__(self): return self.value diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py index 0a43d7c94..a89034d93 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -75,6 +75,7 @@ def dequantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> return q.to(compute_dtype) * scale.to(compute_dtype) +from modules.module.quantized.LinearA8 import LinearA8 from modules.module.quantized.LinearFp8 import LinearFp8 from modules.module.quantized.LinearGGUFA8 import LinearGGUFA8 from modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear @@ -177,13 +178,17 @@ def replace_linear_with_quantized_layers( elif dtype.quantize_fp8(): construct_fn = make_svd_linear(LinearFp8) if dtype.quantize_svd() else LinearFp8 elif dtype.quantize_intW8A8(): - construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.int8, compute_dtype=torch.bfloat16) + construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.int8, compute_dtype=torch.bfloat16) #FIXME elif 
dtype.quantize_fpW8A8(): - construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) #FIXME elif dtype == DataType.GGUF_A8_INT: - construct_fn = partial(LinearGGUFA8, dtype=torch.int8, compute_dtype=torch.bfloat16) + construct_fn = partial(LinearGGUFA8, dtype=torch.int8, compute_dtype=torch.bfloat16) #FIXME elif dtype == DataType.GGUF_A8_FLOAT: - construct_fn = partial(LinearGGUFA8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + construct_fn = partial(LinearGGUFA8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) #FIXME + elif dtype == DataType.BFLOAT_16_A8_INT or dtype == DataType.FLOAT_16_A8_INT: + construct_fn = partial(LinearA8, dtype=torch.int8) + elif dtype == DataType.BFLOAT_16_A8_FLOAT or dtype == DataType.FLOAT_16_A8_FLOAT: + construct_fn = partial(LinearA8, dtype=torch.float8_e4m3fn) else: return @@ -201,7 +206,7 @@ def replace_linear_with_quantized_layers( #https://github.com/Nerogar/OneTrainer/issues/1050 for name, module in parent_module.named_modules(): assert (not isinstance(module, convert_type) - or isinstance(module, (QuantizedLinearMixin, LinearGGUFA8)) + or isinstance(module, (QuantizedLinearMixin, LinearGGUFA8, LinearA8)) or any(s in name.split('.') for s in keep_in_fp32_modules) or (filters is not None and len(filters) > 0 and not any(f.matches(name) for f in filters)) ), f"Linear layer {name} was not found in model for quantization" From 228c976ff732644d8dc548d887e88fe0e9cf341c Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 7 Nov 2025 15:00:42 +0100 Subject: [PATCH 43/54] fix comment --- modules/module/quantized/LinearA8.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/module/quantized/LinearA8.py b/modules/module/quantized/LinearA8.py index 745c49a9f..c10734b47 100644 --- a/modules/module/quantized/LinearA8.py +++ b/modules/module/quantized/LinearA8.py @@ -60,6 +60,8 @@ def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ctx.save_for_backward(x, weight) #axiswise performs better than tensorwise in tests, even though #it requires another quant during backward - but quant is cheap + + # x @ weight.T + bias return int8_forward_axiswise(x, weight, bias) @staticmethod @@ -68,7 +70,7 @@ def backward(ctx, grad_output: Tensor): grad_x, grad_weight, grad_bias = None, None, None if ctx.needs_input_grad[0]: - # grad_output @ weight.T + # grad_output @ weight grad_x = int8_backward_act_axiswise(grad_output, weight) if ctx.needs_input_grad[1]: # grad_output.T @ x @@ -90,7 +92,7 @@ def backward(ctx, grad_output: Tensor): grad_x, grad_weight, grad_bias = None, None, None if ctx.needs_input_grad[0]: - # grad_output @ weight.T + # grad_output @ weight grad_x = fp8_backward_act_axiswise(grad_output, weight) if ctx.needs_input_grad[1]: # grad_output.T @ x From a030b2345d0b9966b4a6ab43a23e8da7c7047374 Mon Sep 17 00:00:00 2001 From: dxqb Date: Fri, 7 Nov 2025 21:02:52 +0100 Subject: [PATCH 44/54] DataType bugfix --- modules/util/enum/DataType.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py index f1b360118..34b7a9816 100644 --- a/modules/util/enum/DataType.py +++ b/modules/util/enum/DataType.py @@ -21,8 +21,8 @@ class DataType(Enum): GGUF = 'GGUF' GGUF_A8_FLOAT = 'GGUF_A8_FLOAT' 
GGUF_A8_INT = 'GGUF_A8_INT' - FLOAT_16_A8_FLOAT = 'FLOAT_16 A8_FLOAT' - FLOAT_16_A8_INT = 'FLOAT_16 A8_INT' + FLOAT_16_A8_FLOAT = 'FLOAT_16_A8_FLOAT' + FLOAT_16_A8_INT = 'FLOAT_16_A8_INT' BFLOAT_16_A8_FLOAT = 'BFLOAT_16_A8_FLOAT' BFLOAT_16_A8_INT = 'BFLOAT_16_A8_INT' From 9159d247ed80337c39690de15fd8d612f4a75aca Mon Sep 17 00:00:00 2001 From: dxqb Date: Sat, 8 Nov 2025 15:04:14 +0100 Subject: [PATCH 45/54] avoid attention mask --- modules/modelSetup/BaseChromaSetup.py | 2 +- modules/modelSetup/BaseQwenSetup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 183ea3a2f..e3ab10f81 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -228,7 +228,7 @@ def predict( image_seq_len = packed_latent_input.shape[1] image_attention_mask = torch.full((packed_latent_input.shape[0], image_seq_len), True, dtype=torch.bool, device=text_attention_mask.device) - attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) if not torch.all(text_attention_mask) else None packed_predicted_flow = model.transformer( hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index 93984fd95..6aeab5c20 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -147,7 +147,7 @@ def predict( #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294 image_attention_mask=torch.ones((packed_latent_input.shape[0], packed_latent_input.shape[1]), dtype=torch.bool, device=latent_image.device) attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) - attention_mask_2d = attention_mask[:, None, None, :] * attention_mask[:, None, :, None] + attention_mask_2d = attention_mask[:, None, None, :] if not torch.all(text_attention_mask) else None packed_predicted_flow = model.transformer( hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), From 9559ebfe11566e4ca6f28590eab3210e9d733703 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sat, 8 Nov 2025 22:36:58 +0100 Subject: [PATCH 46/54] disable bug workaround - can currently not be reproduced and because of #1109 --- modules/model/ChromaModel.py | 5 +++-- modules/model/QwenModel.py | 5 +++-- modules/modelSetup/BaseChromaSetup.py | 6 ++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/modules/model/ChromaModel.py b/modules/model/ChromaModel.py index 14a0b6ffe..3c8fe6e9a 100644 --- a/modules/model/ChromaModel.py +++ b/modules/model/ChromaModel.py @@ -218,9 +218,10 @@ def encode_text( seq_lengths = bool_attention_mask.sum(dim=1) max_seq_length = seq_lengths.max().item() - if max_seq_length % 16 > 0: + #TODO disabled because of https://github.com/Nerogar/OneTrainer/pull/1109, but it could trigger https://github.com/pytorch/pytorch/issues/165506 again + #if max_seq_length % 16 > 0: #attention processors and/or torch.compile can have issues with uneven sequence length: - max_seq_length += (16 - max_seq_length % 16) + # max_seq_length += (16 - max_seq_length % 16) text_encoder_output = text_encoder_output[:, :max_seq_length, :] bool_attention_mask = bool_attention_mask[:, :max_seq_length] diff --git a/modules/model/QwenModel.py b/modules/model/QwenModel.py index 4c72e0fda..67d3bacf2 100644 --- a/modules/model/QwenModel.py +++ 
b/modules/model/QwenModel.py @@ -174,9 +174,10 @@ def encode_text( seq_lengths = tokens_mask.sum(dim=1) max_seq_length = seq_lengths.max().item() - if max_seq_length % 16 > 0: + #TODO disabled because of https://github.com/Nerogar/OneTrainer/pull/1109, but it could trigger https://github.com/pytorch/pytorch/issues/165506 again + #if max_seq_length % 16 > 0: #attention processors and/or torch.compile can have issues with uneven sequence length: - max_seq_length += (16 - max_seq_length % 16) + # max_seq_length += (16 - max_seq_length % 16) text_encoder_output = text_encoder_output[:, :max_seq_length, :] bool_attention_mask = tokens_mask[:, :max_seq_length].bool() diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 1fdbbbf54..f4331fd00 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -226,11 +226,13 @@ def predict( packed_latent_input = model.pack_latents(latent_input) image_seq_len = packed_latent_input.shape[1] - text_seq_len = text_encoder_output.shape[1] image_attention_mask = torch.full((packed_latent_input.shape[0], image_seq_len), True, dtype=torch.bool, device=text_attention_mask.device) attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) if not torch.all(text_attention_mask) else None - assert image_seq_len % 16 == 0 and (image_seq_len + text_seq_len) % 16 == 0 + #TODO disabled because of https://github.com/Nerogar/OneTrainer/pull/1109, but it could trigger https://github.com/pytorch/pytorch/issues/165506 again + #text_seq_len = text_encoder_output.shape[1] + #assert image_seq_len % 16 == 0 and (image_seq_len + text_seq_len) % 16 == 0 + packed_predicted_flow = model.transformer( hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), timestep=timestep / 1000, From 49d2bc4104864192453aefa1dcdd378b186107e9 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sun, 9 Nov 2025 04:27:10 +0100 Subject: [PATCH 47/54] pad sequence length if an attention mask is necessary anyway --- modules/model/ChromaModel.py | 9 +++++---- modules/model/QwenModel.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/modules/model/ChromaModel.py b/modules/model/ChromaModel.py index 3c8fe6e9a..d62a8d21b 100644 --- a/modules/model/ChromaModel.py +++ b/modules/model/ChromaModel.py @@ -218,10 +218,11 @@ def encode_text( seq_lengths = bool_attention_mask.sum(dim=1) max_seq_length = seq_lengths.max().item() - #TODO disabled because of https://github.com/Nerogar/OneTrainer/pull/1109, but it could trigger https://github.com/pytorch/pytorch/issues/165506 again - #if max_seq_length % 16 > 0: - #attention processors and/or torch.compile can have issues with uneven sequence length: - # max_seq_length += (16 - max_seq_length % 16) + #pad to 16 because attention processors and/or torch.compile can have issues with uneven sequence lengths, but only pad if an attention mask has to be used anyway: + #TODO the second condition could trigger https://github.com/pytorch/pytorch/issues/165506 again, but try like this because no attention mask + #is preferable: https://github.com/Nerogar/OneTrainer/pull/1109 + if max_seq_length % 16 > 0 and (seq_lengths != max_seq_length).any(): + max_seq_length += (16 - max_seq_length % 16) text_encoder_output = text_encoder_output[:, :max_seq_length, :] bool_attention_mask = bool_attention_mask[:, :max_seq_length] diff --git a/modules/model/QwenModel.py b/modules/model/QwenModel.py index 67d3bacf2..afa6c24fe 100644 --- 
a/modules/model/QwenModel.py
+++ b/modules/model/QwenModel.py
@@ -174,10 +174,11 @@ def encode_text(
         seq_lengths = tokens_mask.sum(dim=1)
         max_seq_length = seq_lengths.max().item()
 
-        #TODO disabled because of https://github.com/Nerogar/OneTrainer/pull/1109, but it could trigger https://github.com/pytorch/pytorch/issues/165506 again
-        #if max_seq_length % 16 > 0:
-            #attention processors and/or torch.compile can have issues with uneven sequence length:
-            # max_seq_length += (16 - max_seq_length % 16)
+        #pad to 16 because attention processors and/or torch.compile can have issues with uneven sequence lengths, but only pad if an attention mask has to be used anyway:
+        #TODO the second condition could trigger https://github.com/pytorch/pytorch/issues/165506 again, but try like this because no attention mask
+        #is preferable: https://github.com/Nerogar/OneTrainer/pull/1109
+        if max_seq_length % 16 > 0 and (seq_lengths != max_seq_length).any():
+            max_seq_length += (16 - max_seq_length % 16)
 
         text_encoder_output = text_encoder_output[:, :max_seq_length, :]
         bool_attention_mask = tokens_mask[:, :max_seq_length].bool()

From 69f0fa17ce572add1f1fbe7acde1f1014792c9c3 Mon Sep 17 00:00:00 2001
From: dxqb
Date: Fri, 14 Nov 2025 19:51:26 +0100
Subject: [PATCH 48/54] merge fix

---
 modules/ui/ModelTab.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py
index 90ca1fe6a..615f34d80 100644
--- a/modules/ui/ModelTab.py
+++ b/modules/ui/ModelTab.py
@@ -452,9 +452,9 @@ def __create_base_components(
         row += 1
 
         # compile
-        components.label(self.scroll_frame, row, 3, "Compile transformer blocks",
+        components.label(frame, row, 3, "Compile transformer blocks",
                          tooltip="Uses torch.compile and Triton to significantly speed up training. Only applies to transformer/unet. Disable in case of compatibility issues.")
-        components.switch(self.scroll_frame, row, 4, self.ui_state, "compile")
+        components.switch(frame, row, 4, self.ui_state, "compile")
 
         row += 1
 
From 1739bfedb53966a7c745f97a19dc22c411245006 Mon Sep 17 00:00:00 2001
From: O-J1 <18110006+O-J1@users.noreply.github.com>
Date: Mon, 17 Nov 2025 05:38:20 +1100
Subject: [PATCH 49/54] Fixes [Bug]: Layer filter isn't configured correctly if a preset is loaded

Fixes #1089
Fixes Additional embeddings tab not loading.
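Background on the regression, as a rough illustrative sketch only (plain tkinter rather than this project's UIState wrapper; the variable name and callback below are made up): picking a preset in the dropdown fires the widget's command callback, but loading a saved config writes the preset value into the underlying variable programmatically, so that callback never runs and the layer filter UI is left stale. A variable trace also fires on programmatic writes, which is the mechanism the fix below relies on:

    import tkinter as tk

    root = tk.Tk()
    preset = tk.StringVar(value="full")

    def on_preset_change(*_):
        # runs for user edits AND for programmatic writes (e.g. a loaded preset)
        print("layer filter preset is now:", preset.get())

    preset.trace_add("write", on_preset_change)
    preset.set("custom")  # simulates loading a saved config -> the callback still fires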
--- modules/ui/AdditionalEmbeddingsTab.py | 4 ++++ modules/ui/TrainingTab.py | 26 +++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/modules/ui/AdditionalEmbeddingsTab.py b/modules/ui/AdditionalEmbeddingsTab.py index 465f3ac16..d6cef6ca8 100644 --- a/modules/ui/AdditionalEmbeddingsTab.py +++ b/modules/ui/AdditionalEmbeddingsTab.py @@ -24,6 +24,10 @@ def __init__(self, master, train_config: TrainConfig, ui_state: UIState): ) def refresh_ui(self): + if self.element_list is not None: + self.element_list.destroy() + self.element_list = None + self.widgets_initialized = False self._create_element_list() def create_widget(self, master, element, i, open_command, remove_command, clone_command, save_command): diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index e20f1de71..308ce28ca 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -53,6 +53,10 @@ def __init__(self, master, train_config: TrainConfig, ui_state: UIState): self.presets_list = [] self.prior_custom = "" self.prior_selected = None + self.layer_filter_trace_id = self.ui_state.add_var_trace( + "layer_filter_preset", + self.__on_layer_filter_preset_change, + ) self.scroll_frame = None @@ -793,7 +797,7 @@ def __create_layer_frame(self, master, row): frame.grid_columnconfigure(0, weight=1) components.label(frame, 0, 0, "Layer Filter", - tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own. A blank custom field will train all layers.") + tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own. \n \n A blank custom field or 'Full' will train all layers. Custom/Full is not necessarily recommended nor supported, intended for advanced users.") self.layer_selector = components.options( frame, 0, 1, self.presets_list, self.ui_state, "layer_filter_preset", command=self.__preset_set_layer_choice @@ -835,6 +839,7 @@ def __preset_set_layer_choice(self, selected: str): if selected == "custom": # Restore prior custom text and allow editing + regex toggle + self.__show_layer_entry() self.layer_entry.configure(state="normal", fg_color=self.layer_entry_fg_color, text_color=self.layer_entry_text_color) self.layer_entry.cget('textvariable').set(self.prior_custom) self.regex_label.grid() @@ -866,8 +871,27 @@ def __preset_set_layer_choice(self, selected: str): self.regex_label.grid_remove() self.regex_switch.grid_remove() + if selected == "full" and not patterns: + self.__hide_layer_entry() + else: + self.__show_layer_entry() + self.prior_selected = selected + def __on_layer_filter_preset_change(self): + if not self.layer_selector: + return + selected = self.ui_state.get_var("layer_filter_preset").get() + self.__preset_set_layer_choice(selected) + + def __hide_layer_entry(self): + if self.layer_entry and self.layer_entry.winfo_manager(): + self.layer_entry.grid_remove() + + def __show_layer_entry(self): + if self.layer_entry and not self.layer_entry.winfo_manager(): + self.layer_entry.grid() + def __open_optimizer_params_window(self): window = OptimizerParamsWindow(self.master, self.train_config, self.ui_state) self.master.wait_window(window) From 968b2a93cd463fee691fc5a468487bbf44215f91 Mon Sep 17 00:00:00 2001 From: dxqb <183307934+dxqb@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:02:48 +0100 Subject: [PATCH 50/54] Simplify tooltip text for layer filter --- modules/ui/TrainingTab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui/TrainingTab.py 
b/modules/ui/TrainingTab.py index 308ce28ca..905527b86 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -797,7 +797,7 @@ def __create_layer_frame(self, master, row): frame.grid_columnconfigure(0, weight=1) components.label(frame, 0, 0, "Layer Filter", - tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own. \n \n A blank custom field or 'Full' will train all layers. Custom/Full is not necessarily recommended nor supported, intended for advanced users.") + tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own.\nA blank custom field or 'Full' will train all layers.") self.layer_selector = components.options( frame, 0, 1, self.presets_list, self.ui_state, "layer_filter_preset", command=self.__preset_set_layer_choice From 6f2d2f5bd45c5e3adf488c5daddfbb5e31c06976 Mon Sep 17 00:00:00 2001 From: O-J1 <18110006+O-J1@users.noreply.github.com> Date: Mon, 17 Nov 2025 06:08:17 +1100 Subject: [PATCH 51/54] Tweak tooltip text a little more --- modules/ui/TrainingTab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index 905527b86..cb769227c 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -797,7 +797,7 @@ def __create_layer_frame(self, master, row): frame.grid_columnconfigure(0, weight=1) components.label(frame, 0, 0, "Layer Filter", - tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own.\nA blank custom field or 'Full' will train all layers.") + tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own.\nA blank 'custom' field or 'Full' will train all layers.") self.layer_selector = components.options( frame, 0, 1, self.presets_list, self.ui_state, "layer_filter_preset", command=self.__preset_set_layer_choice From c9a7d074284b216be072508c82d973c9deaeb8e0 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sat, 22 Nov 2025 17:03:13 +0100 Subject: [PATCH 52/54] fix to Dtypes, to avoid leaving weights at float32 --- modules/util/enum/DataType.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py index 34b7a9816..43cee577e 100644 --- a/modules/util/enum/DataType.py +++ b/modules/util/enum/DataType.py @@ -45,6 +45,14 @@ def torch_dtype( return torch.bfloat16 case DataType.TFLOAT_32: return torch.float32 + case DataType.FLOAT_16_A8_FLOAT: + return torch.float16 + case DataType.FLOAT_16_A8_INT: + return torch.float16 + case DataType.BFLOAT_16_A8_FLOAT: + return torch.bfloat16 + case DataType.BFLOAT_16_A8_INT: + return torch.bfloat16 case _: return None From 50982b7d58ed6a9d64ced46bb42125cc88adac8a Mon Sep 17 00:00:00 2001 From: dxqb Date: Sat, 22 Nov 2025 17:19:12 +0100 Subject: [PATCH 53/54] UI update --- .../modelLoader/mixin/HFModelLoaderMixin.py | 8 +- modules/ui/ModelTab.py | 99 +++++++++---------- modules/util/config/TrainConfig.py | 17 ++-- modules/util/ui/components.py | 6 +- 4 files changed, 64 insertions(+), 66 deletions(-) diff --git a/modules/modelLoader/mixin/HFModelLoaderMixin.py b/modules/modelLoader/mixin/HFModelLoaderMixin.py index 0cb3d3c92..f101b0cc7 100644 --- a/modules/modelLoader/mixin/HFModelLoaderMixin.py +++ b/modules/modelLoader/mixin/HFModelLoaderMixin.py @@ -122,10 +122,6 @@ def __load_sub_module( if hasattr(sub_module, '_fix_state_dict_keys_on_load'): sub_module._fix_state_dict_keys_on_load(state_dict) - #TODO why is it necessary to iterate by 
key names from the state dict? - #why not iterate through the object model, like replace_linear_... does? - #would avoid key replacements as follows. - if hasattr(sub_module, "_checkpoint_conversion_mapping"): #required for loading the text encoder of Qwen new_state_dict = {} for k, v in state_dict.items(): @@ -135,6 +131,10 @@ def __load_sub_module( new_state_dict[new_k] = v state_dict = new_state_dict + #this loads the actual data from the state dict into tensors that are 'meta' tensors up to this point + #tensors that will be quantized are loaded at their original dtype. non-quantized tensors are converted + #to their intended dtype here + #TODO why not quantize here? would avoid to load the entire model first (high RAM) and then quantize (low RAM) for key, value in state_dict.items(): module = sub_module tensor_name = key diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 615f34d80..510f7bd02 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -78,55 +78,6 @@ def refresh_ui(self): elif self.train_config.model_type.is_hi_dream(): self.__setup_hi_dream_ui(base_frame) - self.__create_quantization_frame(self.scroll_frame, row=1, column=0) - - def __create_quantization_frame( - self, - master, - row: int, - column: int, - ): - frame = ctk.CTkFrame(master=master, corner_radius=5, width=300) - frame.grid(row=row, column=column, padx=5, pady=5, sticky="nsew") - frame.grid_columnconfigure(0, weight=1) - frame.grid_columnconfigure(1, weight=10) - - presets = [] - if self.train_config.model_type.is_stable_diffusion(): #TODO simplify and de-duplicate with layer filter on training tab - presets = sd_presets - elif self.train_config.model_type.is_stable_diffusion_xl(): - presets = sdxl_presets - elif self.train_config.model_type.is_stable_diffusion_3(): - presets = sd3_presets - elif self.train_config.model_type.is_wuerstchen(): - presets = sc_presets - elif self.train_config.model_type.is_pixart(): - presets = pixart_presets - elif self.train_config.model_type.is_flux(): - presets = flux_presets - elif self.train_config.model_type.is_qwen(): - presets = qwen_presets - elif self.train_config.model_type.is_chroma(): - presets = chroma_presets - elif self.train_config.model_type.is_sana(): - presets = sana_presets - elif self.train_config.model_type.is_hunyuan_video(): - presets = hunyuan_video_presets - elif self.train_config.model_type.is_hi_dream(): - presets = hidream_presets - else: - presets = {"full": []} - - components.layer_filter_entry(frame, 0, 0, self.ui_state, - preset_var_name="quantization_layer_filter_preset", presets=presets, - preset_label="Quantization Layer Filter", - preset_tooltip="Select a preset defining which layers to quantize. Quantization of certain layers can decrease model quality. Only applies to the transformer/unet", - entry_var_name="quantization_layer_filter", - entry_tooltip="Comma-separated list of layers to quantize. Regular expressions (if toggled) are supported. Any model layer with a matching name will be quantized", - regex_var_name="quantization_layer_filter_regex", - regex_tooltip="If enabled, layer filter patterns are interpreted as regular expressions. 
Otherwise, simple substring matching is used.", - ) - def __setup_stable_diffusion_ui(self, frame): row = 0 row = self.__create_base_dtype_components(frame, row) @@ -343,16 +294,20 @@ def __setup_hi_dream_ui(self, frame): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __create_dtype_options(self, include_none:bool=True, include_gguf=False) -> list[tuple[str, DataType]]: + def __create_dtype_options(self, include_none:bool=True, include_gguf=False, include_quantization=True) -> list[tuple[str, DataType]]: options = [ ("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16), ("float16", DataType.FLOAT_16), - ("float8", DataType.FLOAT_8), - # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 - ("nfloat4", DataType.NFLOAT_4), ] + if include_quantization: + options += [ + ("float8", DataType.FLOAT_8), + # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 + ("nfloat4", DataType.NFLOAT_4), + ] + if include_gguf: options.append(("GGUF", DataType.GGUF)) @@ -451,6 +406,44 @@ def __create_base_components( row += 1 + presets = [] + if self.train_config.model_type.is_stable_diffusion(): #TODO simplify and de-duplicate with layer filter on training tab + presets = sd_presets + elif self.train_config.model_type.is_stable_diffusion_xl(): + presets = sdxl_presets + elif self.train_config.model_type.is_stable_diffusion_3(): + presets = sd3_presets + elif self.train_config.model_type.is_wuerstchen(): + presets = sc_presets + elif self.train_config.model_type.is_pixart(): + presets = pixart_presets + elif self.train_config.model_type.is_flux(): + presets = flux_presets + elif self.train_config.model_type.is_qwen(): + presets = qwen_presets + elif self.train_config.model_type.is_chroma(): + presets = chroma_presets + elif self.train_config.model_type.is_sana(): + presets = sana_presets + elif self.train_config.model_type.is_hunyuan_video(): + presets = hunyuan_video_presets + elif self.train_config.model_type.is_hi_dream(): + presets = hidream_presets + else: + presets = {"full": []} + + components.label(frame, row, 0, "Quantization") + components.layer_filter_entry(frame, row, 1, self.ui_state, + preset_var_name="quantization_layer_filter_preset", presets=presets, + preset_label="Layer Filter", + preset_tooltip="Select a preset defining which layers to quantize. Quantization of certain layers can decrease model quality. Only applies to the transformer/unet", + entry_var_name="quantization_layer_filter", + entry_tooltip="Comma-separated list of layers to quantize. Regular expressions (if toggled) are supported. Any model layer with a matching name will be quantized", + regex_var_name="quantization_layer_filter_regex", + regex_tooltip="If enabled, layer filter patterns are interpreted as regular expressions. Otherwise, simple substring matching is used.", + frame_color="transparent", + ) + # compile components.label(frame, row, 3, "Compile transformer blocks", tooltip="Uses torch.compile and Triton to significantly speed up training. Only applies to transformer/unet. 
Disable in case of compatibility issues.") diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index 043ef7d01..90561d38a 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -409,6 +409,11 @@ class TrainConfig(BaseConfig): # transformer transformer: TrainModelPartConfig + # quantization + quantization_layer_filter: str + quantization_layer_filter_preset: str + quantization_layer_filter_regex: bool + # text encoder text_encoder: TrainModelPartConfig text_encoder_layer_skip: int @@ -971,7 +976,7 @@ def default_values() -> 'TrainConfig': prior.weight_dtype = DataType.NONE data.append(("prior", prior, TrainModelPartConfig, False)) - # prior + # transformer transformer = TrainModelPartConfig.default_values() transformer.model_name = "" transformer.train = True @@ -980,6 +985,11 @@ def default_values() -> 'TrainConfig': transformer.weight_dtype = DataType.NONE data.append(("transformer", transformer, TrainModelPartConfig, False)) + #quantization layer filter + data.append(("quantization_layer_filter", "", str, False)) + data.append(("quantization_layer_filter_preset", "full", str, False)) + data.append(("quantization_layer_filter_regex", False, bool, False)) + # text encoder text_encoder = TrainModelPartConfig.default_values() text_encoder.train = True @@ -1062,11 +1072,6 @@ def default_values() -> 'TrainConfig': data.append(("layer_filter_preset", "full", str, False)) data.append(("layer_filter_regex", False, bool, False)) - #quantization layer filter - data.append(("quantization_layer_filter", "", str, False)) - data.append(("quantization_layer_filter_preset", "full", str, False)) - data.append(("quantization_layer_filter_regex", False, bool, False)) - # embedding data.append(("embedding_learning_rate", None, float, True)) data.append(("preserve_embedding_norm", False, bool, False)) diff --git a/modules/util/ui/components.py b/modules/util/ui/components.py index 0281916f5..961d213a0 100644 --- a/modules/util/ui/components.py +++ b/modules/util/ui/components.py @@ -296,9 +296,9 @@ def time_entry(master, row, column, ui_state: UIState, var_name: str, unit_var_n return frame -def layer_filter_entry(master, row, column, ui_state: UIState, preset_var_name: str, preset_label: str, preset_tooltip: str, presets, entry_var_name, entry_tooltip: str, regex_var_name, regex_tooltip: str): - frame = ctk.CTkFrame(master=master, corner_radius=5) - frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") +def layer_filter_entry(master, row, column, ui_state: UIState, preset_var_name: str, preset_label: str, preset_tooltip: str, presets, entry_var_name, entry_tooltip: str, regex_var_name, regex_tooltip: str, frame_color=None): + frame = ctk.CTkFrame(master=master, corner_radius=5, fg_color=frame_color) + frame.grid(row=row, column=column, padx=5, pady=5, sticky="nsew") frame.grid_columnconfigure(0, weight=1) layer_entry = entry( From 16a6015e369e6936f3007ea34ab94ef54b826396 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sat, 22 Nov 2025 17:23:02 +0100 Subject: [PATCH 54/54] UI update --- modules/ui/ModelTab.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 510f7bd02..bcef5764f 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -294,20 +294,16 @@ def __setup_hi_dream_ui(self, frame): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __create_dtype_options(self, include_none:bool=True, 
include_gguf=False, include_quantization=True) -> list[tuple[str, DataType]]: + def __create_dtype_options(self, include_none:bool=True, include_gguf=False) -> list[tuple[str, DataType]]: options = [ ("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16), ("float16", DataType.FLOAT_16), + ("float8", DataType.FLOAT_8), + # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 + ("nfloat4", DataType.NFLOAT_4), ] - if include_quantization: - options += [ - ("float8", DataType.FLOAT_8), - # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 - ("nfloat4", DataType.NFLOAT_4), - ] - if include_gguf: options.append(("GGUF", DataType.GGUF))