diff --git a/modules/modelLoader/ChromaEmbeddingModelLoader.py b/modules/modelLoader/ChromaEmbeddingModelLoader.py index ea03d5757..c113d0877 100644 --- a/modules/modelLoader/ChromaEmbeddingModelLoader.py +++ b/modules/modelLoader/ChromaEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class ChromaEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> ChromaModel | None: base_model_loader = ChromaModelLoader() embedding_loader = ChromaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/ChromaFineTuneModelLoader.py b/modules/modelLoader/ChromaFineTuneModelLoader.py index c9bbb1891..0915d5509 100644 --- a/modules/modelLoader/ChromaFineTuneModelLoader.py +++ b/modules/modelLoader/ChromaFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class ChromaFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> ChromaModel | None: base_model_loader = ChromaModelLoader() embedding_loader = ChromaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/ChromaLoRAModelLoader.py b/modules/modelLoader/ChromaLoRAModelLoader.py index aaa38058c..37d07644a 100644 --- a/modules/modelLoader/ChromaLoRAModelLoader.py +++ b/modules/modelLoader/ChromaLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class ChromaLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> ChromaModel | None: base_model_loader = ChromaModelLoader() lora_model_loader = ChromaLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/FluxEmbeddingModelLoader.py 
b/modules/modelLoader/FluxEmbeddingModelLoader.py index 0960c051a..09baf3dbb 100644 --- a/modules/modelLoader/FluxEmbeddingModelLoader.py +++ b/modules/modelLoader/FluxEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class FluxEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> FluxModel | None: base_model_loader = FluxModelLoader() embedding_loader = FluxEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/FluxFineTuneModelLoader.py b/modules/modelLoader/FluxFineTuneModelLoader.py index 4fac329a3..599cdfb10 100644 --- a/modules/modelLoader/FluxFineTuneModelLoader.py +++ b/modules/modelLoader/FluxFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class FluxFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> FluxModel | None: base_model_loader = FluxModelLoader() embedding_loader = FluxEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/FluxLoRAModelLoader.py b/modules/modelLoader/FluxLoRAModelLoader.py index 84d9133dc..a7a554370 100644 --- a/modules/modelLoader/FluxLoRAModelLoader.py +++ b/modules/modelLoader/FluxLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class FluxLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> FluxModel | None: base_model_loader = FluxModelLoader() lora_model_loader = FluxLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/HiDreamEmbeddingModelLoader.py b/modules/modelLoader/HiDreamEmbeddingModelLoader.py index 7539fdf70..4583db403 100644 --- 
a/modules/modelLoader/HiDreamEmbeddingModelLoader.py +++ b/modules/modelLoader/HiDreamEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HiDreamEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HiDreamModel | None: base_model_loader = HiDreamModelLoader() embedding_loader = HiDreamEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/HiDreamFineTuneModelLoader.py b/modules/modelLoader/HiDreamFineTuneModelLoader.py index d9de7fc7a..3d4b27cb4 100644 --- a/modules/modelLoader/HiDreamFineTuneModelLoader.py +++ b/modules/modelLoader/HiDreamFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HiDreamFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HiDreamModel | None: base_model_loader = HiDreamModelLoader() embedding_loader = HiDreamEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/HiDreamLoRAModelLoader.py b/modules/modelLoader/HiDreamLoRAModelLoader.py index 524ab3e89..90898fc73 100644 --- a/modules/modelLoader/HiDreamLoRAModelLoader.py +++ b/modules/modelLoader/HiDreamLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HiDreamLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HiDreamModel | None: base_model_loader = HiDreamModelLoader() lora_model_loader = HiDreamLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py b/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py index d2e28c9b7..eeebf2f98 100644 --- 
a/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py +++ b/modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HunyuanVideoEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HunyuanVideoModel | None: base_model_loader = HunyuanVideoModelLoader() embedding_loader = HunyuanVideoEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py b/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py index b2ae057ff..f81c8f94b 100644 --- a/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py +++ b/modules/modelLoader/HunyuanVideoFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HunyuanVideoFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HunyuanVideoModel | None: base_model_loader = HunyuanVideoModelLoader() embedding_loader = HunyuanVideoEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/HunyuanVideoLoRAModelLoader.py b/modules/modelLoader/HunyuanVideoLoRAModelLoader.py index 2aa4555c2..a88b8d836 100644 --- a/modules/modelLoader/HunyuanVideoLoRAModelLoader.py +++ b/modules/modelLoader/HunyuanVideoLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class HunyuanVideoLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> HunyuanVideoModel | None: base_model_loader = HunyuanVideoModelLoader() lora_model_loader = HunyuanVideoLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py 
b/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py index 52b00ca26..1126589b1 100644 --- a/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py +++ b/modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class PixArtAlphaEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> PixArtAlphaModel | None: base_model_loader = PixArtAlphaModelLoader() embedding_loader = PixArtAlphaEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py b/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py index 45c52c331..1f2ff7de8 100644 --- a/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py +++ b/modules/modelLoader/PixArtAlphaFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class PixArtAlphaFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> PixArtAlphaModel | None: base_model_loader = PixArtAlphaModelLoader() embedding_loader = PixArtAlphaEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/PixArtAlphaLoRAModelLoader.py b/modules/modelLoader/PixArtAlphaLoRAModelLoader.py index 028f7289f..27b42e96f 100644 --- a/modules/modelLoader/PixArtAlphaLoRAModelLoader.py +++ b/modules/modelLoader/PixArtAlphaLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class PixArtAlphaLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> PixArtAlphaModel | None: base_model_loader = PixArtAlphaModelLoader() lora_model_loader = PixArtAlphaLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git 
a/modules/modelLoader/QwenFineTuneModelLoader.py b/modules/modelLoader/QwenFineTuneModelLoader.py index 0172d2b79..3e6d89709 100644 --- a/modules/modelLoader/QwenFineTuneModelLoader.py +++ b/modules/modelLoader/QwenFineTuneModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class QwenFineTuneModelLoader( @@ -31,6 +32,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> QwenModel | None: base_model_loader = QwenModelLoader() model = QwenModel(model_type=model_type) @@ -38,6 +40,6 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) return model diff --git a/modules/modelLoader/QwenLoRAModelLoader.py b/modules/modelLoader/QwenLoRAModelLoader.py index 41003cd07..3ae11114c 100644 --- a/modules/modelLoader/QwenLoRAModelLoader.py +++ b/modules/modelLoader/QwenLoRAModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class QwenLoRAModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> QwenModel | None: base_model_loader = QwenModelLoader() lora_model_loader = QwenLoRALoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) return model diff --git a/modules/modelLoader/SanaEmbeddingModelLoader.py b/modules/modelLoader/SanaEmbeddingModelLoader.py index 6781d2ba9..7215fe7c7 100644 --- a/modules/modelLoader/SanaEmbeddingModelLoader.py +++ b/modules/modelLoader/SanaEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class SanaEmbeddingModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: base_model_loader = SanaModelLoader() embedding_loader = SanaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/SanaFineTuneModelLoader.py b/modules/modelLoader/SanaFineTuneModelLoader.py index 82b14e49c..a713cb555 100644 --- a/modules/modelLoader/SanaFineTuneModelLoader.py +++ b/modules/modelLoader/SanaFineTuneModelLoader.py @@ 
-7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class SanaFineTuneModelLoader( @@ -32,6 +33,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: base_model_loader = SanaModelLoader() embedding_loader = SanaEmbeddingLoader() @@ -41,7 +43,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/SanaLoRAModelLoader.py b/modules/modelLoader/SanaLoRAModelLoader.py index bfaf8a1da..f326cb6fa 100644 --- a/modules/modelLoader/SanaLoRAModelLoader.py +++ b/modules/modelLoader/SanaLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class SanaLoRAModelLoader( @@ -33,6 +34,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: base_model_loader = SanaModelLoader() lora_model_loader = SanaLoRALoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py b/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py index 86d0ff771..4cdbddff2 100644 --- a/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py +++ b/modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusion3EmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusion3Model | None: base_model_loader = StableDiffusion3ModelLoader() embedding_loader = StableDiffusion3EmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py b/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py index 8a9ee4477..a3fd19add 100644 --- a/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py +++ 
b/modules/modelLoader/StableDiffusion3FineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusion3FineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusion3Model | None: base_model_loader = StableDiffusion3ModelLoader() embedding_loader = StableDiffusion3EmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/StableDiffusion3LoRAModelLoader.py b/modules/modelLoader/StableDiffusion3LoRAModelLoader.py index 994d87518..34a53024d 100644 --- a/modules/modelLoader/StableDiffusion3LoRAModelLoader.py +++ b/modules/modelLoader/StableDiffusion3LoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusion3LoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusion3Model | None: base_model_loader = StableDiffusion3ModelLoader() lora_model_loader = StableDiffusion3LoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py b/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py index e698eb3d5..d81bb673d 100644 --- a/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py +++ b/modules/modelLoader/StableDiffusionEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionEmbeddingModelLoader( @@ -46,6 +47,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionModel | None: base_model_loader = StableDiffusionModelLoader() embedding_loader = StableDiffusionEmbeddingLoader() @@ -55,7 +57,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/StableDiffusionFineTuneModelLoader.py 
b/modules/modelLoader/StableDiffusionFineTuneModelLoader.py index acfdefef0..adc5448fc 100644 --- a/modules/modelLoader/StableDiffusionFineTuneModelLoader.py +++ b/modules/modelLoader/StableDiffusionFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionFineTuneModelLoader( @@ -46,6 +47,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionModel | None: base_model_loader = StableDiffusionModelLoader() embedding_loader = StableDiffusionEmbeddingLoader() @@ -55,7 +57,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/StableDiffusionLoRAModelLoader.py b/modules/modelLoader/StableDiffusionLoRAModelLoader.py index ff6503aa1..958a684b5 100644 --- a/modules/modelLoader/StableDiffusionLoRAModelLoader.py +++ b/modules/modelLoader/StableDiffusionLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionLoRAModelLoader( @@ -47,6 +48,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionModel | None: base_model_loader = StableDiffusionModelLoader() lora_model_loader = StableDiffusionLoRALoader() @@ -57,7 +59,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py b/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py index 201e993cb..bec654875 100644 --- a/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py +++ b/modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionXLEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionXLModel | None: base_model_loader = StableDiffusionXLModelLoader() embedding_loader = StableDiffusionXLEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) 
embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py b/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py index 1e1efdc30..2f7e09ea3 100644 --- a/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py +++ b/modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionXLFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionXLModel | None: base_model_loader = StableDiffusionXLModelLoader() embedding_loader = StableDiffusionXLEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py b/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py index 15192bd81..d84cb7a2d 100644 --- a/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py +++ b/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class StableDiffusionXLLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> StableDiffusionXLModel | None: base_model_loader = StableDiffusionXLModelLoader() lora_model_loader = StableDiffusionXLLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/WuerstchenEmbeddingModelLoader.py b/modules/modelLoader/WuerstchenEmbeddingModelLoader.py index e10e7b6de..3cfe1b624 100644 --- a/modules/modelLoader/WuerstchenEmbeddingModelLoader.py +++ b/modules/modelLoader/WuerstchenEmbeddingModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class WuerstchenEmbeddingModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> WuerstchenModel | None: base_model_loader = WuerstchenModelLoader() embedding_loader = WuerstchenEmbeddingLoader() @@ -43,7 +45,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model: - base_model_loader.load(model, 
model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.embedding.model_name, model_names) return model diff --git a/modules/modelLoader/WuerstchenFineTuneModelLoader.py b/modules/modelLoader/WuerstchenFineTuneModelLoader.py index 0182b31a0..df8ac40b1 100644 --- a/modules/modelLoader/WuerstchenFineTuneModelLoader.py +++ b/modules/modelLoader/WuerstchenFineTuneModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class WuerstchenFineTuneModelLoader( @@ -34,6 +35,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> WuerstchenModel | None: base_model_loader = WuerstchenModelLoader() embedding_loader = WuerstchenEmbeddingLoader() @@ -43,7 +45,7 @@ def load( self._load_internal_data(model, model_names.base_model) model.model_spec = self._load_default_model_spec(model_type) - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) embedding_loader.load(model, model_names.base_model, model_names) return model diff --git a/modules/modelLoader/WuerstchenLoRAModelLoader.py b/modules/modelLoader/WuerstchenLoRAModelLoader.py index fcd9f91ab..51d46760f 100644 --- a/modules/modelLoader/WuerstchenLoRAModelLoader.py +++ b/modules/modelLoader/WuerstchenLoRAModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter class WuerstchenLoRAModelLoader( @@ -35,6 +36,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> WuerstchenModel | None: base_model_loader = WuerstchenModelLoader() lora_model_loader = WuerstchenLoRALoader() @@ -45,7 +47,7 @@ def load( model.model_spec = self._load_default_model_spec(model_type) if model_names.base_model is not None: - base_model_loader.load(model, model_type, model_names, weight_dtypes) + base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters) lora_model_loader.load(model, model_names) embedding_loader.load(model, model_names.lora, model_names) diff --git a/modules/modelLoader/chroma/ChromaModelLoader.py b/modules/modelLoader/chroma/ChromaModelLoader.py index 1424e566a..9c3cf326e 100644 --- a/modules/modelLoader/chroma/ChromaModelLoader.py +++ b/modules/modelLoader/chroma/ChromaModelLoader.py @@ -3,10 +3,10 @@ from modules.model.ChromaModel import ChromaModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin -from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import torch @@ -33,10 +33,11 @@ def __load_internal( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( - model, model_type, weight_dtypes, 
base_model_name, transformer_model_name, vae_model_name, + model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quant_filters, ) else: raise Exception("not an internal model") @@ -49,6 +50,7 @@ def __load_diffusers( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] if not transformer_model_name: @@ -101,10 +103,10 @@ def __load_diffusers( transformer_model_name, #avoid loading the transformer in float32: torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(), - quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) else: transformer = self._load_diffusers_sub_module( @@ -113,6 +115,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -130,6 +133,7 @@ def __load_safetensors( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): #no single file .safetensors for Chroma available at the time of writing this code raise NotImplementedError("Loading of single file Chroma models not supported. Use the diffusers model instead. Optionally, transformer-only safetensor files can be loaded by overriding the transformer.") @@ -140,12 +144,13 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -153,7 +158,7 @@ def load( try: self.__load_diffusers( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -161,7 +166,7 @@ def load( try: self.__load_safetensors( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: diff --git a/modules/modelLoader/flux/FluxModelLoader.py b/modules/modelLoader/flux/FluxModelLoader.py index 254f9a3ee..fbdc2012f 100644 --- a/modules/modelLoader/flux/FluxModelLoader.py +++ b/modules/modelLoader/flux/FluxModelLoader.py @@ -3,10 +3,10 @@ from modules.model.FluxModel import FluxModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin -from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import 
torch @@ -36,11 +36,12 @@ def __load_internal( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, + include_text_encoder_1, include_text_encoder_2, quant_filters, ) else: raise Exception("not an internal model") @@ -55,6 +56,7 @@ def __load_diffusers( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] transformers_sub = [] @@ -137,10 +139,10 @@ def __load_diffusers( transformer_model_name, #avoid loading the transformer in float32: torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(), - quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None, ) transformer = self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) else: transformer = self._load_diffusers_sub_module( @@ -149,6 +151,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -170,6 +173,7 @@ def __load_safetensors( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): transformer = FluxTransformer2DModel.from_single_file( #always load transformer separately even though FluxPipeLine.from_single_file() could load it, to avoid loading in float32: @@ -222,7 +226,7 @@ def __load_safetensors( print("text encoder 2 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -240,13 +244,14 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) return except Exception: @@ -255,7 +260,7 @@ def load( try: self.__load_diffusers( model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) return except Exception: @@ -264,7 +269,7 @@ def load( try: self.__load_safetensors( model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) return except 
Exception: diff --git a/modules/modelLoader/hiDream/HiDreamModelLoader.py b/modules/modelLoader/hiDream/HiDreamModelLoader.py index 92bce0a9a..e6e5ca462 100644 --- a/modules/modelLoader/hiDream/HiDreamModelLoader.py +++ b/modules/modelLoader/hiDream/HiDreamModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import ( AutoencoderKL, @@ -42,11 +43,12 @@ def __load_internal( include_text_encoder_2: bool, include_text_encoder_3: bool, include_text_encoder_4: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, text_encoder_4_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, include_text_encoder_4, + include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, include_text_encoder_4, quant_filters, ) else: raise Exception("not an internal model") @@ -63,6 +65,7 @@ def __load_diffusers( include_text_encoder_2: bool, include_text_encoder_3: bool, include_text_encoder_4: bool, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] transformers_sub = [] @@ -191,6 +194,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -218,6 +222,7 @@ def __load_safetensors( include_text_encoder_2: bool, include_text_encoder_3: bool, include_text_encoder_4: bool, + quant_filters: list[ModuleFilter], ): pipeline = HiDreamImagePipeline.from_single_file( pretrained_model_link_or_path=base_model_name, @@ -264,7 +269,7 @@ def __load_safetensors( print("text encoder 2 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -290,6 +295,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -298,7 +304,7 @@ def load( model, model_type, weight_dtypes, model_names.base_model, model_names.text_encoder_4, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, model_names.include_text_encoder_4, + model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters, ) self.__after_load(model) return @@ -310,7 +316,7 @@ def load( model, model_type, weight_dtypes, model_names.base_model, model_names.text_encoder_4, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, model_names.include_text_encoder_4, + model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters, ) self.__after_load(model) return @@ -322,7 +328,7 @@ def load( model, model_type, weight_dtypes, model_names.base_model, model_names.text_encoder_4, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, model_names.include_text_encoder_4, + model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters, ) self.__after_load(model) return diff 
--git a/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py b/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py index fe044ca03..e7e8996fb 100644 --- a/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py +++ b/modules/modelLoader/hunyuanVideo/HunyuanVideoModelLoader.py @@ -7,6 +7,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import ( AutoencoderKLHunyuanVideo, @@ -32,11 +33,12 @@ def __load_internal( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, + include_text_encoder_1, include_text_encoder_2, quant_filters, ) else: raise Exception("not an internal model") @@ -50,6 +52,7 @@ def __load_diffusers( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] transformers_sub = [] @@ -133,6 +136,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -153,6 +157,7 @@ def __load_safetensors( vae_model_name: str, include_text_encoder_1: bool, include_text_encoder_2: bool, + quant_filters: list[ModuleFilter], ): pipeline = HunyuanVideoPipeline.from_single_file( pretrained_model_link_or_path=base_model_name, @@ -192,7 +197,7 @@ def __load_safetensors( print("text encoder 2 (clip l) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -214,13 +219,14 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) self.__after_load(model) return @@ -230,7 +236,7 @@ def load( try: self.__load_diffusers( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) self.__after_load(model) return @@ -240,7 +246,7 @@ def load( try: self.__load_safetensors( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, - model_names.include_text_encoder, model_names.include_text_encoder_2, + model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters, ) self.__after_load(model) return diff --git a/modules/modelLoader/mixin/HFModelLoaderMixin.py b/modules/modelLoader/mixin/HFModelLoaderMixin.py index 163702417..f5e48b057 100644 --- a/modules/modelLoader/mixin/HFModelLoaderMixin.py +++ b/modules/modelLoader/mixin/HFModelLoaderMixin.py @@ -6,11 +6,10 @@ from itertools import repeat from modules.util.enum.DataType import DataType +from 
modules.util.ModuleFilter import ModuleFilter from modules.util.quantization_util import ( is_quantized_parameter, - replace_linear_with_fp8_layers, - replace_linear_with_int8_layers, - replace_linear_with_nf4_layers, + replace_linear_with_quantized_layers, ) import torch @@ -32,6 +31,7 @@ def __load_sub_module( dtype: DataType, train_dtype: DataType, keep_in_fp32_modules: list[str] | None, + quant_filters: list[ModuleFilter] | None, pretrained_model_name_or_path: str, subfolder: str | None, model_filename: str, @@ -42,12 +42,7 @@ def __load_sub_module( keep_in_fp32_modules = [] with accelerate.init_empty_weights(): - if dtype.quantize_nf4(): - replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, copy_parameters=False) - elif dtype.quantize_int8(): - replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, copy_parameters=False) - elif dtype.quantize_fp8(): - replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, copy_parameters=False) + replace_linear_with_quantized_layers(sub_module, dtype, keep_in_fp32_modules, quant_filters, copy_parameters=False) is_local = os.path.isdir(pretrained_model_name_or_path) @@ -120,10 +115,6 @@ def __load_sub_module( if hasattr(sub_module, '_fix_state_dict_keys_on_load'): sub_module._fix_state_dict_keys_on_load(state_dict) - #TODO why is it necessary to iterate by key names from the state dict? - #why not iterate through the object model, like replace_linear_... does? - #would avoid key replacements as follows. - if hasattr(sub_module, "_checkpoint_conversion_mapping"): #required for loading the text encoder of Qwen new_state_dict = {} for k, v in state_dict.items(): @@ -133,6 +124,10 @@ def __load_sub_module( new_state_dict[new_k] = v state_dict = new_state_dict + #this loads the actual data from the state dict into tensors that are 'meta' tensors up to this point + #tensors that will be quantized are loaded at their original dtype. non-quantized tensors are converted + #to their intended dtype here + #TODO why not quantize here? 
would avoid to load the entire model first (high RAM) and then quantize (low RAM) for key, value in state_dict.items(): module = sub_module tensor_name = key @@ -195,6 +190,7 @@ def _load_transformers_sub_module( dtype=dtype, train_dtype=train_dtype, keep_in_fp32_modules=module_type._keep_in_fp32_modules, + quant_filters=None, pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, model_filename="model.safetensors", @@ -209,6 +205,7 @@ def _load_diffusers_sub_module( train_dtype: DataType, pretrained_model_name_or_path: str, subfolder: str | None = None, + quant_filters: list[ModuleFilter] | None = None, ): user_agent = { "file_type": "model", @@ -230,6 +227,7 @@ def _load_diffusers_sub_module( dtype=dtype, train_dtype=train_dtype, keep_in_fp32_modules=module_type._keep_in_fp32_modules, + quant_filters=quant_filters, pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, model_filename="diffusion_pytorch_model.safetensors", @@ -243,21 +241,16 @@ def __convert_sub_module_to_dtype( dtype: DataType, train_dtype: DataType, keep_in_fp32_modules: list[str] | None, + quant_filters: list[ModuleFilter] | None, ): if keep_in_fp32_modules is None: keep_in_fp32_modules = [] - if dtype.quantize_nf4(): - replace_linear_with_nf4_layers(sub_module, keep_in_fp32_modules, copy_parameters=True) - elif dtype.quantize_int8(): - replace_linear_with_int8_layers(sub_module, keep_in_fp32_modules, copy_parameters=True) - elif dtype.quantize_fp8(): - replace_linear_with_fp8_layers(sub_module, keep_in_fp32_modules, copy_parameters=True) + replace_linear_with_quantized_layers(sub_module, dtype, keep_in_fp32_modules, quant_filters, copy_parameters=True) for module_name, module in sub_module.named_modules(): param_iter = [(x, y[0], y[1]) for x, y in zip(repeat(False), module._parameters.items(), strict=False)] buffer_iter = [(x, y[0], y[1]) for x, y in zip(repeat(True), module._buffers.items(), strict=False)] - for is_buffer, tensor_name, value in param_iter + buffer_iter: if value is not None and torch.is_floating_point(value): old_type = type(value) @@ -281,6 +274,7 @@ def _convert_transformers_sub_module_to_dtype( sub_module: nn.Module, dtype: DataType, train_dtype: DataType, + quant_filters: list[ModuleFilter] | None = None, ): module_type = type(sub_module) @@ -289,6 +283,7 @@ def _convert_transformers_sub_module_to_dtype( dtype, train_dtype, module_type._keep_in_fp32_modules, + quant_filters, ) def _convert_diffusers_sub_module_to_dtype( @@ -296,12 +291,14 @@ def _convert_diffusers_sub_module_to_dtype( sub_module: nn.Module, dtype: DataType, train_dtype: DataType, + quant_filters: list[ModuleFilter] | None = None, ): return self.__convert_sub_module_to_dtype( sub_module, dtype, train_dtype, None, + quant_filters, ) def _prepare_sub_modules(self, pretrained_model_name_or_path: str, diffusers_modules: list[str], transformers_modules: list[str]): diff --git a/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py b/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py index 1428a4079..aef761608 100644 --- a/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py +++ b/modules/modelLoader/pixartAlpha/PixArtAlphaModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import AutoencoderKL, DDIMScheduler, Transformer2DModel from transformers import 
T5EncoderModel, T5Tokenizer @@ -24,9 +25,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -37,6 +39,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer = T5Tokenizer.from_pretrained( base_model_name, @@ -78,6 +81,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -93,19 +97,20 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> PixArtAlphaModel | None: stacktraces = [] base_model_name = model_names.base_model try: - self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/qwen/QwenModelLoader.py b/modules/modelLoader/qwen/QwenModelLoader.py index 85b5817f1..633b0f8ed 100644 --- a/modules/modelLoader/qwen/QwenModelLoader.py +++ b/modules/modelLoader/qwen/QwenModelLoader.py @@ -3,10 +3,10 @@ from modules.model.QwenModel import QwenModel from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin -from modules.util.enum.DataType import DataType from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import torch @@ -33,10 +33,11 @@ def __load_internal( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( - model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, + model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quant_filters, ) else: raise Exception("not an internal model") @@ -49,6 +50,7 @@ def __load_diffusers( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): diffusers_sub = [] if not transformer_model_name: @@ -103,10 +105,10 @@ def __load_diffusers( subfolder="transformer", #avoid loading the transformer in float32: torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(), - quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None, ) transformer = 
self._convert_diffusers_sub_module_to_dtype( - transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) else: transformer = self._load_diffusers_sub_module( @@ -115,6 +117,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -132,6 +135,7 @@ def __load_safetensors( base_model_name: str, transformer_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): #no single file .safetensors for Qwen available at the time of writing this code raise NotImplementedError("Loading of single file Qwen models not supported. Use the diffusers model instead. Optionally, transformer-only safetensor files can be loaded by overriding the transformer.") @@ -142,12 +146,13 @@ def load( #TODO share code between models model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] try: self.__load_internal( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -155,7 +160,7 @@ def load( #TODO share code between models try: self.__load_diffusers( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: @@ -163,7 +168,7 @@ def load( #TODO share code between models try: self.__load_safetensors( - model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters, ) return except Exception: diff --git a/modules/modelLoader/sana/SanaModelLoader.py b/modules/modelLoader/sana/SanaModelLoader.py index e4c2d638f..0c2b5fc4e 100644 --- a/modules/modelLoader/sana/SanaModelLoader.py +++ b/modules/modelLoader/sana/SanaModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import AutoencoderDC, DPMSolverMultistepScheduler, SanaTransformer2DModel from transformers import Gemma2Model, GemmaTokenizer @@ -24,9 +25,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -37,6 +39,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer = GemmaTokenizer.from_pretrained( base_model_name, @@ -78,6 +81,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -93,19 +97,20 @@ def load( model_type: 
ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ) -> SanaModel | None: stacktraces = [] base_model_name = model_names.base_model try: - self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py b/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py index 2e55c8b5a..6aee93334 100644 --- a/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py +++ b/modules/modelLoader/stableDiffusion/StableDiffusionModelLoader.py @@ -9,6 +9,7 @@ from modules.util.enum.NoiseScheduler import NoiseScheduler from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter import torch @@ -55,9 +56,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -68,6 +70,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer = CLIPTokenizer.from_pretrained( base_model_name, @@ -113,6 +116,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "unet", + quant_filters, ) image_depth_processor = DPTImageProcessor.from_pretrained( @@ -158,6 +162,7 @@ def __load_ckpt( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): state_dict = torch.load(base_model_name, weights_only=True) state_dict = self.__fix_nai_model(state_dict) @@ -196,7 +201,7 @@ def __load_ckpt( pipeline.text_encoder, weight_dtypes.text_encoder, weight_dtypes.train_dtype ) unet = self._convert_diffusers_sub_module_to_dtype( - pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype + pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -215,6 +220,7 @@ def __load_safetensors( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): num_in_channels = 4 if model_type.has_mask_input(): @@ -251,7 +257,7 @@ def __load_safetensors( pipeline.text_encoder, weight_dtypes.text_encoder, weight_dtypes.train_dtype ) unet = self._convert_diffusers_sub_module_to_dtype( - pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype + pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -269,6 +275,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -276,19 
+283,19 @@ def load( model.sd_config_filename = self._get_sd_config_name(model_type, model_names.base_model) try: - self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_diffusers(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py b/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py index f5fa6871f..0852b20a9 100644 --- a/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py +++ b/modules/modelLoader/stableDiffusion3/StableDiffusion3ModelLoader.py @@ -6,6 +6,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer @@ -27,11 +28,12 @@ def __load_internal( include_text_encoder_1: bool, include_text_encoder_2: bool, include_text_encoder_3: bool, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): self.__load_diffusers( model, model_type, weight_dtypes, base_model_name, vae_model_name, - include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, + include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, quant_filters, ) else: raise Exception("not an internal model") @@ -46,6 +48,7 @@ def __load_diffusers( include_text_encoder_1: bool, include_text_encoder_2: bool, include_text_encoder_3: bool, + quant_filters: list[ModuleFilter], ): #no call to self._prepare_sub_modules, because SAI polluted their sd3 / sd3.5 medium repo text encoders with fp16 files @@ -133,6 +136,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "transformer", + quant_filters, ) model.model_type = model_type @@ -156,6 +160,7 @@ def __load_safetensors( include_text_encoder_1: bool, include_text_encoder_2: bool, include_text_encoder_3: bool, + quant_filters: list[ModuleFilter], ): pipeline = StableDiffusion3Pipeline.from_single_file( pretrained_model_link_or_path=base_model_name, @@ -220,7 +225,7 @@ def __load_safetensors( print("text encoder 3 (t5) not loaded, continuing without it") transformer = self._convert_diffusers_sub_module_to_dtype( - pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype + pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -240,6 +245,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = 
None, ): stacktraces = [] @@ -247,7 +253,7 @@ def load( self.__load_internal( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, + model_names.include_text_encoder_3, quant_filters, ) return except Exception: @@ -257,7 +263,7 @@ def load( self.__load_diffusers( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, + model_names.include_text_encoder_3, quant_filters, ) return except Exception: @@ -267,7 +273,7 @@ def load( self.__load_safetensors( model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, model_names.include_text_encoder, model_names.include_text_encoder_2, - model_names.include_text_encoder_3, + model_names.include_text_encoder_3, quant_filters, ) return except Exception: diff --git a/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py b/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py index 75bf33d98..4ffb61777 100644 --- a/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py +++ b/modules/modelLoader/stableDiffusionXL/StableDiffusionXLModelLoader.py @@ -9,6 +9,7 @@ from modules.util.enum.NoiseScheduler import NoiseScheduler from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import ( AutoencoderKL, @@ -46,9 +47,10 @@ def __load_internal( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(base_model_name, "meta.json")): - self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name) + self.__load_diffusers(model, model_type, weight_dtypes, base_model_name, vae_model_name, quant_filters) else: raise Exception("not an internal model") @@ -59,6 +61,7 @@ def __load_diffusers( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): tokenizer_1 = CLIPTokenizer.from_pretrained( base_model_name, @@ -117,6 +120,7 @@ def __load_diffusers( weight_dtypes.train_dtype, base_model_name, "unet", + quant_filters, ) model.model_type = model_type @@ -135,7 +139,11 @@ def __load_ckpt( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): + if quant_filters is not None and len(quant_filters) > 0: + raise NotImplementedError("Quantization not implemented for loading ckpt files") + pipeline = StableDiffusionXLPipeline.from_single_file( pretrained_model_link_or_path=base_model_name, original_config=model.sd_config_filename, @@ -176,6 +184,7 @@ def __load_safetensors( weight_dtypes: ModelWeightDtypes, base_model_name: str, vae_model_name: str, + quant_filters: list[ModuleFilter], ): if model_type.has_conditioning_image_input(): pipeline = StableDiffusionXLInpaintPipeline.from_single_file( @@ -216,7 +225,7 @@ def __load_safetensors( pipeline.text_encoder_2, weight_dtypes.text_encoder_2, weight_dtypes.train_dtype ) unet = self._convert_diffusers_sub_module_to_dtype( - pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype + pipeline.unet, weight_dtypes.unet, weight_dtypes.train_dtype, quant_filters, ) model.model_type = model_type @@ -234,6 +243,7 @@ def load( model_type: ModelType, 
model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -241,19 +251,19 @@ def load( model.sd_config_filename = self._get_sd_config_name(model_type, model_names.base_model) try: - self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_internal(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_diffusers(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_diffusers(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) try: - self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model) + self.__load_safetensors(model, model_type, weight_dtypes, model_names.base_model, model_names.vae_model, quant_filters) return except Exception: stacktraces.append(traceback.format_exc()) diff --git a/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py b/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py index c7c237cc7..d3419706f 100644 --- a/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py +++ b/modules/modelLoader/wuerstchen/WuerstchenModelLoader.py @@ -8,6 +8,7 @@ from modules.util.enum.ModelType import ModelType from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes +from modules.util.ModuleFilter import ModuleFilter from diffusers import DDPMWuerstchenScheduler from diffusers.models import StableCascadeUNet @@ -32,6 +33,7 @@ def __load_internal( prior_model_name: str, effnet_encoder_model_name: str, decoder_model_name: str, + quant_filters: list[ModuleFilter], ): if os.path.isfile(os.path.join(prior_model_name, "meta.json")): self.__load_diffusers( @@ -42,6 +44,7 @@ def __load_internal( "", # pass an empty prior name, so it's always loaded from the backup effnet_encoder_model_name, decoder_model_name, + quant_filters, ) else: raise Exception("not an internal model") @@ -55,6 +58,7 @@ def __load_diffusers( prior_prior_model_name: str, effnet_encoder_model_name: str, decoder_model_name: str, + quant_filters: list[ModuleFilter], ): if model_type.is_wuerstchen_v2(): decoder_tokenizer = CLIPTokenizer.from_pretrained( @@ -127,6 +131,7 @@ def __load_diffusers( weight_dtypes.train_dtype, prior_model_name, "prior", + quant_filters, ) elif model_type.is_stable_cascade(): if prior_prior_model_name: @@ -140,7 +145,7 @@ def __load_diffusers( prior_prior = StableCascadeUNet(**prior_config) prior_prior.load_state_dict(convert_stable_cascade_ckpt_to_diffusers(load_file(prior_prior_model_name))) prior_prior = self._convert_diffusers_sub_module_to_dtype( - prior_prior, weight_dtypes.prior, weight_dtypes.fallback_train_dtype + prior_prior, weight_dtypes.prior, weight_dtypes.fallback_train_dtype, quant_filters, ) else: prior_prior = self._load_diffusers_sub_module( @@ -149,6 +154,7 @@ def __load_diffusers( weight_dtypes.fallback_train_dtype, prior_model_name, "prior", + quant_filters, ) prior_tokenizer = CLIPTokenizer.from_pretrained( @@ -196,6 +202,7 @@ def load( model_type: ModelType, model_names: ModelNames, weight_dtypes: ModelWeightDtypes, + quant_filters: list[ModuleFilter] | None = None, ): stacktraces = [] @@ -211,7 +218,7 @@ def load( weight_dtypes, prior_model_name, 
effnet_encoder_model_name, - decoder_model_name, + decoder_model_name, quant_filters, ) return except Exception: @@ -225,7 +232,7 @@ def load( prior_model_name, prior_prior_model_name, effnet_encoder_model_name, - decoder_model_name, + decoder_model_name, quant_filters, ) return except Exception: diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 62e60d2d3..7a7847df7 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -82,9 +82,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, @@ -225,7 +225,6 @@ def predict( ) packed_latent_input = model.pack_latents(latent_input) - image_seq_len = packed_latent_input.shape[1] image_attention_mask = torch.full((packed_latent_input.shape[0], image_seq_len), True, dtype=torch.bool, device=text_attention_mask.device) attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) if not torch.all(text_attention_mask) else None diff --git a/modules/modelSetup/BaseFluxSetup.py b/modules/modelSetup/BaseFluxSetup.py index 487bce323..e524c96bf 100644 --- a/modules/modelSetup/BaseFluxSetup.py +++ b/modules/modelSetup/BaseFluxSetup.py @@ -84,10 +84,10 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.text_encoder_2_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.text_encoder_2_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/BaseHiDreamSetup.py b/modules/modelSetup/BaseHiDreamSetup.py index c4bf06e94..48e691ecd 100644 --- a/modules/modelSetup/BaseHiDreamSetup.py +++ b/modules/modelSetup/BaseHiDreamSetup.py @@ -98,12 +98,12 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype) - quantize_layers(model.text_encoder_4, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype, config) + 
quantize_layers(model.text_encoder_4, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/BaseHunyuanVideoSetup.py b/modules/modelSetup/BaseHunyuanVideoSetup.py index 628d090e9..bbb90b71a 100644 --- a/modules/modelSetup/BaseHunyuanVideoSetup.py +++ b/modules/modelSetup/BaseHunyuanVideoSetup.py @@ -85,10 +85,10 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config) model.vae.enable_tiling() diff --git a/modules/modelSetup/BasePixArtAlphaSetup.py b/modules/modelSetup/BasePixArtAlphaSetup.py index 6d4ef738c..8240fb5f4 100644 --- a/modules/modelSetup/BasePixArtAlphaSetup.py +++ b/modules/modelSetup/BasePixArtAlphaSetup.py @@ -82,9 +82,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index ad0978bad..dc7115274 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -77,9 +77,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def predict( self, diff --git a/modules/modelSetup/BaseSanaSetup.py b/modules/modelSetup/BaseSanaSetup.py index 782f3dd76..84078ff6f 100644 --- a/modules/modelSetup/BaseSanaSetup.py +++ b/modules/modelSetup/BaseSanaSetup.py @@ -94,9 +94,9 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, 
model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/BaseStableDiffusion3Setup.py b/modules/modelSetup/BaseStableDiffusion3Setup.py index 728f09a27..5015b21af 100644 --- a/modules/modelSetup/BaseStableDiffusion3Setup.py +++ b/modules/modelSetup/BaseStableDiffusion3Setup.py @@ -86,11 +86,11 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.transformer, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/BaseStableDiffusionSetup.py b/modules/modelSetup/BaseStableDiffusionSetup.py index 799b44255..0fc6ed0df 100644 --- a/modules/modelSetup/BaseStableDiffusionSetup.py +++ b/modules/modelSetup/BaseStableDiffusionSetup.py @@ -68,9 +68,9 @@ def setup_optimizations( config.weight_dtypes().embedding if config.train_any_embedding() else None, ], config.enable_autocast_cache) - quantize_layers(model.text_encoder, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.train_dtype) - quantize_layers(model.unet, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.unet, self.train_device, model.train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/BaseStableDiffusionXLSetup.py b/modules/modelSetup/BaseStableDiffusionXLSetup.py index 79b4edeec..37121951a 100644 --- a/modules/modelSetup/BaseStableDiffusionXLSetup.py +++ b/modules/modelSetup/BaseStableDiffusionXLSetup.py @@ -76,10 +76,10 @@ def setup_optimizations( config.enable_autocast_cache, ) - quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype) - quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype) - quantize_layers(model.vae, self.train_device, model.vae_train_dtype) - quantize_layers(model.unet, self.train_device, model.train_dtype) + quantize_layers(model.text_encoder_1, self.train_device, model.train_dtype, config) + quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config) + quantize_layers(model.vae, self.train_device, model.vae_train_dtype, config) + quantize_layers(model.unet, self.train_device, model.train_dtype, config) def _setup_embeddings( self, diff --git a/modules/modelSetup/BaseWuerstchenSetup.py b/modules/modelSetup/BaseWuerstchenSetup.py index 973919de4..23b3440a5 100644 --- a/modules/modelSetup/BaseWuerstchenSetup.py +++ b/modules/modelSetup/BaseWuerstchenSetup.py @@ -104,12 +104,12 @@ def setup_optimizations( ) if model.model_type.is_wuerstchen_v2(): - quantize_layers(model.decoder_text_encoder, 
self.train_device, model.train_dtype) - quantize_layers(model.decoder_decoder, self.train_device, model.train_dtype) - quantize_layers(model.decoder_vqgan, self.train_device, model.train_dtype) - quantize_layers(model.effnet_encoder, self.train_device, model.effnet_encoder_train_dtype) - quantize_layers(model.prior_text_encoder, self.train_device, model.train_dtype) - quantize_layers(model.prior_prior, self.train_device, model.prior_train_dtype) + quantize_layers(model.decoder_text_encoder, self.train_device, model.train_dtype, config) + quantize_layers(model.decoder_decoder, self.train_device, model.train_dtype, config) + quantize_layers(model.decoder_vqgan, self.train_device, model.train_dtype, config) + quantize_layers(model.effnet_encoder, self.train_device, model.effnet_encoder_train_dtype, config) + quantize_layers(model.prior_text_encoder, self.train_device, model.train_dtype, config) + quantize_layers(model.prior_prior, self.train_device, model.prior_train_dtype, config) def _setup_embeddings( self, diff --git a/modules/module/quantized/LinearA8.py b/modules/module/quantized/LinearA8.py new file mode 100644 index 000000000..c10734b47 --- /dev/null +++ b/modules/module/quantized/LinearA8.py @@ -0,0 +1,151 @@ +from modules.util.quantization_util import ( + quantize_fp8_axiswise, + quantize_int8_axiswise, +) +from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit + +import torch +from torch import Tensor, nn + + +def int8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_int8_axiswise(x, dim=-1) + w_8, w_scale = quantize_int8_axiswise(weight, dim=-1) + res = torch._int_mm(x_8, w_8.T) + res_scaled = res.to(x.dtype).mul_(w_scale.T).mul_(x_scale) + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def fp8_forward_axiswise(x: Tensor, weight: Tensor, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) + w_8, w_scale = quantize_fp8_axiswise(weight, dim=-1) + one = torch.ones(1, device=x.device) + res = torch._scaled_mm(x_8, w_8.T, scale_a=one, scale_b=one, out_dtype=x.dtype) + res_scaled = res.mul_(w_scale.T).mul_(x_scale) + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def int8_backward_act_axiswise(output: Tensor, weight: Tensor) -> Tensor: + output_8, output_scale = quantize_int8_axiswise(output, dim=-1) + w_8, w_scale = quantize_int8_axiswise(weight, dim=0) + #almost always, grad outputs are already contiguous and this is a no-op. 
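+    #(the .contiguous() below is there because triton_mm_8bit, like torch._int_mm, presumably expects a contiguous lhs buffer)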
But there are some grad outputs from SDXL that are non-contiguous: + output_8 = output_8.contiguous() + mm_res = triton_mm_8bit(output_8, w_8) + return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) + +def fp8_backward_act_axiswise(output: Tensor, weight: Tensor) -> Tensor: + output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) + w_8, w_scale = quantize_fp8_axiswise(weight, dim=0) + mm_res = triton_mm_8bit(output_8.contiguous(), w_8) + return mm_res.to(output.dtype).mul_(w_scale).mul_(output_scale) + +def int8_backward_weight_axiswise(output: Tensor, x: Tensor) -> Tensor: + output_8, output_scale = quantize_int8_axiswise(output, dim=0) + x_8, x_scale = quantize_int8_axiswise(x, dim=0) + #TODO could be more efficient using a kernel that accepts a non-contiguous lhs matrix + mm_res = triton_mm_8bit(output_8.T.contiguous(), x_8) + return mm_res.to(x.dtype).mul_(output_scale.T).mul_(x_scale) + +def fp8_backward_weight_axiswise(output: Tensor, x: Tensor) -> Tensor: + output_8, output_scale = quantize_fp8_axiswise(output, dim=0) + x_8, x_scale = quantize_fp8_axiswise(x, dim=0) + mm_res = triton_mm_8bit(output_8.T.contiguous(), x_8) + return mm_res.to(x.dtype).mul_(output_scale.T).mul_(x_scale) + +class LinearInt8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(x, weight) + #axiswise performs better than tensorwise in tests, even though + #it requires another quant during backward - but quant is cheap + + # x @ weight.T + bias + return int8_forward_axiswise(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output: Tensor): + x, weight = ctx.saved_tensors + + grad_x, grad_weight, grad_bias = None, None, None + if ctx.needs_input_grad[0]: + # grad_output @ weight + grad_x = int8_backward_act_axiswise(grad_output, weight) + if ctx.needs_input_grad[1]: + # grad_output.T @ x + grad_weight = int8_backward_weight_axiswise(grad_output, x) + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_x, grad_weight, grad_bias + +class LinearFp8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(x, weight) + return fp8_forward_axiswise(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output: Tensor): + x, weight = ctx.saved_tensors + + grad_x, grad_weight, grad_bias = None, None, None + if ctx.needs_input_grad[0]: + # grad_output @ weight + grad_x = fp8_backward_act_axiswise(grad_output, weight) + if ctx.needs_input_grad[1]: + # grad_output.T @ x + grad_weight = fp8_backward_weight_axiswise(grad_output, x) + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_x, grad_weight, grad_bias + + +class LinearA8(nn.Linear): + def __init__(self, dtype, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert dtype in [torch.int8, torch.float8_e4m3fn] + self._dtype = dtype + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + x = x_orig.to(self.weight.dtype).reshape(-1, x_orig.shape[-1]) + if x.shape[0] > 16: + if self._dtype == torch.int8: + y = LinearInt8Function.apply(x, self.weight, self.bias) + else: + y = LinearFp8Function.apply(x, self.weight, self.bias) + return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) + else: + return super().forward(x_orig) + + + +def run_benchmark(fn, desc, steps=10000, warmup=500, compile=False): + if compile: + fn = torch.compile(fn, fullgraph=True) + from tqdm import tqdm + for _ in 
range(warmup): + fn() + torch.cuda.synchronize() + for _ in tqdm(range(steps), desc=desc): + fn() + torch.cuda.synchronize() + + +@torch.no_grad() +def benchmark(m, k, n, device = 'cuda'): + output = torch.randn(m, n, device=device, dtype=torch.bfloat16) + x = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + run_benchmark(lambda: int8_forward_axiswise(x, weight), "forward int8", compile=True) + run_benchmark(lambda: int8_backward_weight_axiswise(output, x), "backward weight int8", compile=True) + run_benchmark(lambda: fp8_forward_axiswise(x, weight), "forward fp8", compile=True) + run_benchmark(lambda: fp8_backward_weight_axiswise(output, x), "backward weight fp8", compile=True) + + +if __name__ == "__main__": + benchmark(2 * 1024 + 50, 3072, 3072 + 16) + #benchmark_fp8(2080, 3072) diff --git a/modules/module/quantized/LinearFp8.py b/modules/module/quantized/LinearFp8.py index f0ad404d5..5a54b2f38 100644 --- a/modules/module/quantized/LinearFp8.py +++ b/modules/module/quantized/LinearFp8.py @@ -19,7 +19,6 @@ def __init__(self, *args, **kwargs): self.fp8_dtype = torch.float8_e4m3fn self._scale = torch.tensor(1.0, dtype=torch.float) self.register_buffer("scale", self._scale) - self.compute_dtype = None def original_weight_shape(self) -> tuple[int, ...]: @@ -31,7 +30,7 @@ def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch. else: return self.weight.detach().to(dtype=dtype) - def quantize(self, device: torch.device | None = None): + def quantize(self, device: torch.device | None = None, **kwargs): if self.is_quantized: return self.is_quantized = True diff --git a/modules/module/quantized/LinearGGUFA8.py b/modules/module/quantized/LinearGGUFA8.py new file mode 100644 index 000000000..1a85eeb3f --- /dev/null +++ b/modules/module/quantized/LinearGGUFA8.py @@ -0,0 +1,69 @@ +from modules.module.quantized.LinearA8 import ( + fp8_backward_act_axiswise, + fp8_forward_axiswise, + int8_backward_act_axiswise, + int8_forward_axiswise, +) + +import torch +from torch import Tensor + +from diffusers.quantizers.gguf.utils import GGUFLinear, dequantize_gguf_tensor + +import gguf + +UNQUANTIZED_TYPES = [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16] + +class LinearGGUFIntA8RequantFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(weight) + #axiswise performs better than tensorwise in tests, even though + #it requires another requant during backward - but requant is cheap + return int8_forward_axiswise(x, weight, bias) + + @staticmethod + def backward(ctx, output: Tensor): + if ctx.needs_input_grad != (True, False, False): + raise NotImplementedError("GGUF cannot be used for full finetuning") + weight, = ctx.saved_tensors + return int8_backward_act_axiswise(output, weight), None, None + +class LinearGGUFFpA8RequantFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(weight) + return fp8_forward_axiswise(x, weight, bias) + + @staticmethod + def backward(ctx, output: Tensor): + if ctx.needs_input_grad != (True, False, False): + raise NotImplementedError("GGUF cannot be used for full finetuning") + weight, = ctx.saved_tensors + return fp8_backward_act_axiswise(output, weight), None, None + +class LinearGGUFA8(GGUFLinear): + def __init__(self, dtype: torch.dtype, *args, 
**kwargs): + super().__init__(*args, **kwargs) + + assert dtype in [torch.int8, torch.float8_e4m3fn] + self._dtype = dtype + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + assert not self.weight.requires_grad + x = x_orig.to(self.compute_dtype).reshape(-1, x_orig.shape[-1]) + w = dequantize_gguf_tensor(self.weight).detach() + + if x.shape[0] > 16 and self.weight.quant_type not in UNQUANTIZED_TYPES: + if self._dtype == torch.int8: + y = LinearGGUFIntA8RequantFunction.apply(x, w, self.bias) + else: + y = LinearGGUFFpA8RequantFunction.apply(x, w, self.bias) + else: + x = x.to(self.compute_dtype) + w = w.to(self.compute_dtype) + bias = self.bias.to(self.compute_dtype) if self.bias is not None else None + y = torch.nn.functional.linear(x, w, bias) + + assert y.dtype == self.compute_dtype + return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) diff --git a/modules/module/quantized/LinearNf4.py b/modules/module/quantized/LinearNf4.py index 2a4bfbf17..2cbc46bc2 100644 --- a/modules/module/quantized/LinearNf4.py +++ b/modules/module/quantized/LinearNf4.py @@ -62,7 +62,7 @@ def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch else: return self.weight.detach().to(dtype=dtype) - def quantize(self, device: torch.device | None = None): + def quantize(self, device: torch.device | None = None, **kwargs): if self.is_quantized: return self.is_quantized = True diff --git a/modules/module/quantized/LinearSVD.py b/modules/module/quantized/LinearSVD.py new file mode 100644 index 000000000..085c39374 --- /dev/null +++ b/modules/module/quantized/LinearSVD.py @@ -0,0 +1,98 @@ +from contextlib import suppress + +from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin +from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin + +import torch + + +class BaseLinearSVD( + QuantizedModuleMixin, + QuantizedLinearMixin, +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +def _get_tensor_hash(t: torch.Tensor) -> str: + t = t.flatten().to(torch.float32) + vals = torch.stack([ + torch.sum(t), + torch.sum(t**2), + torch.sum(torch.sin(t)), + torch.sum(torch.cos(t)) + ]) + return vals.cpu().numpy().tobytes().hex() + +def make_svd_linear(linear_class): + class LinearSVD( + linear_class, + BaseLinearSVD, + ): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__svd_is_quantized = False + + #use parameters instead of buffer to allow offloading: + self.svd_up = torch.nn.Parameter(torch.empty(()), requires_grad=False) + self.svd_down = torch.nn.Parameter(torch.empty(()), requires_grad=False) + + def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + if self.__svd_is_quantized: + return (self.svd_up @ self.svd_down).to(dtype) + super().unquantized_weight(dtype, device) + else: + return super().unquantized_weight(dtype, device) + + @torch.no_grad() + def quantize(self, rank: int, svd_dtype: torch.dtype, device: torch.device | None = None, cache_dir: str | None = None, max_cache_rank: int = 128, **kwargs): + if self.__svd_is_quantized: + return + self.__svd_is_quantized = True + + W = super().unquantized_weight(torch.float32, device) + orig_device = W.device + if device is not None: + W = W.to(device=device) + + U = None + if cache_dir is not None: + filename = cache_dir + "/" + _get_tensor_hash(W) + ".pt" + with suppress(FileNotFoundError): + U, S, Vh = torch.load(filename, map_location=W.device) + + if U is None: + #use full svd - 
torch.svd_lowrank is not reducing the quant range nearly as much: + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + + if cache_dir is not None: + torch.save(( + U[:, :max_cache_rank].clone(), + S[:max_cache_rank].clone(), + Vh[:max_cache_rank, :].clone(), + ), filename) + + U_r = U[:, :rank] + S_r = S[:rank] + Vh_r = Vh[:rank, :] + + svd_down = Vh_r.clone().contiguous().to(svd_dtype) + svd_up = (U_r * S_r.unsqueeze(0)).clone().contiguous().to(svd_dtype) + weight = (W - (svd_up @ svd_down)).to(dtype=self.weight.dtype) + + if device is not None: + weight = weight.to(device=orig_device) + svd_up = svd_up.to(device=orig_device) + svd_down = svd_down.to(device=orig_device) + + self.requires_grad_(False) + self.svd_up = torch.nn.Parameter(svd_up, requires_grad=False) + self.svd_down = torch.nn.Parameter(svd_down, requires_grad=False) + self.weight.data = weight + super().quantize(device=device, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert self.__svd_is_quantized + assert not self.svd_down.requires_grad and not self.svd_up.requires_grad + return ((x @ self.svd_down.T) @ self.svd_up.T).to(x.dtype) + super().forward(x) + + return LinearSVD diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py new file mode 100644 index 000000000..c58f6e34e --- /dev/null +++ b/modules/module/quantized/LinearW8A8.py @@ -0,0 +1,189 @@ + +from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin +from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin +from modules.util.quantization_util import ( + dequantize, + quantize_fp8_axiswise, + quantize_fp8_tensorwise, + quantize_int8_axiswise, + quantize_int8_tensorwise, +) +from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit + +import torch +from torch import Tensor, nn + + +def int8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_int8_axiswise(x, dim=-1) + res = torch._int_mm(x_8, weight.T) + res_scaled = res.to(x.dtype).mul_(weight_scale * x_scale) + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor=None) -> Tensor: + x_8, x_scale = quantize_fp8_axiswise(x, dim=-1) + one = torch.ones(1, device=x.device) + res = torch._scaled_mm(x_8, weight.T, scale_a=one, scale_b=weight_scale.float(), out_dtype=x.dtype) + res_scaled = res.mul_(x_scale) #much faster than scaled by _scaled_mm + if bias is not None: + res_scaled.add_(bias.to(x.dtype)) + return res_scaled + +def int8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor: + output_8, output_scale = quantize_int8_axiswise(output, dim=-1) + mm_res = triton_mm_8bit(output_8.contiguous(), weight) + return mm_res.to(output.dtype).mul_(weight_scale * output_scale) + +def fp8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor: + output_8, output_scale = quantize_fp8_axiswise(output, dim=-1) + mm_res = triton_mm_8bit(output_8.contiguous(), weight) + return mm_res.to(output.dtype).mul_(weight_scale * output_scale) + + +class LinearInt8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(weight, weight_scale) + return int8_forward_tokenwise(x, weight, weight_scale, bias) + + @staticmethod + def backward(ctx, output: Tensor): + if 
ctx.needs_input_grad != (True, False, False, False): + raise NotImplementedError("Int A8W8 cannot be used for full finetuning") + + weight, weight_scale = ctx.saved_tensors + return int8_backward_axiswise(output, weight, weight_scale), None, None, None + +class LinearFp8Function(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, weight_scale: float, bias: Tensor | None) -> Tensor: + ctx.save_for_backward(weight, weight_scale) + return fp8_forward_tokenwise(x, weight, weight_scale, bias) + + @staticmethod + def backward(ctx, output: Tensor): + if ctx.needs_input_grad != (True, False, False, False): + raise NotImplementedError("Float A8W8 cannot be used for full finetuning") + + weight, weight_scale = ctx.saved_tensors + return fp8_backward_axiswise(output, weight, weight_scale), None, None, None + +class LinearW8A8( + nn.Linear, + QuantizedModuleMixin, + QuantizedLinearMixin, +): + def __init__(self, dtype, compute_dtype, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert dtype in [torch.int8, torch.float8_e4m3fn] + self._dtype = dtype + self._compute_dtype = compute_dtype + + self.__is_quantized = False + self.register_buffer("scale", torch.tensor(1.0, dtype=torch.float32)) + + def original_weight_shape(self) -> tuple[int, ...]: + return self.weight.shape + + def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + return dequantize(self.weight.detach(), self.scale, self._compute_dtype).to(dtype) + + @torch.no_grad() + def quantize(self, device: torch.device | None = None, **kwargs): + if self.__is_quantized: + return + self.__is_quantized = True + + weight = self.weight.detach() + orig_device = weight.device + if device is not None: + weight = weight.to(device=device) + if self._dtype == torch.int8: + weight, scale = quantize_int8_tensorwise(weight) + else: + weight, scale = quantize_fp8_tensorwise(weight) + + if device is not None: + weight = weight.to(device=orig_device) + + self.requires_grad_(False) + self.weight.data = weight + + self.scale.copy_(scale) + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + assert not self.weight.requires_grad + assert self.__is_quantized + x = x_orig.to(self._compute_dtype).reshape(-1, x_orig.shape[-1]) + + if x.shape[0] > 16: + if self._dtype == torch.int8: + y = LinearInt8Function.apply(x, self.weight, self.scale, self.bias) + else: + y = LinearFp8Function.apply(x, self.weight, self.scale, self.bias) + else: + w = dequantize(self.weight, self.scale, compute_dtype=self._compute_dtype) + y = torch.nn.functional.linear(x, w, self.bias.to(self._compute_dtype)) + + assert y.dtype == self._compute_dtype + return y.reshape(x_orig.shape[:-1] + (y.shape[-1], )) + +def run_benchmark(fn, desc, steps=10000, warmup=500, compile=False): + if compile: + fn = torch.compile(fn, fullgraph=True) + from tqdm import tqdm + for _ in range(warmup): + fn() + torch.cuda.synchronize() + for _ in tqdm(range(steps), desc=desc): + fn() + torch.cuda.synchronize() + + +@torch.no_grad() +def benchmark_int8(m, k, n, device = 'cuda'): + x = torch.randn(m,k, device=device, dtype=torch.bfloat16) + x_8 = torch.ones (m,k, device=device, dtype=torch.int8) + y = torch.randn(m,n, device=device, dtype=torch.bfloat16) + y_8 = torch.ones (m,n, device=device, dtype=torch.int8) + w_8 = torch.ones (n,k, device=device, dtype=torch.int8) + w_scale = torch.ones(1, device=device) + + + run_benchmark(lambda: torch._int_mm(x_8, w_8.T), "torch mm int") + run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), 
"triton mm int") + def torch_backward(a, b): + torch._int_mm(a, b.T.contiguous().T) + run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8") + run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward int8") + + run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int", compile=True) + run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int", compile=True) + + +@torch.no_grad() +def benchmark_fp8(m, k, n, device = 'cuda'): + x = torch.randn(m,k, device=device, dtype=torch.bfloat16) + x_8 = torch.ones (m,k, device=device, dtype=torch.float8_e4m3fn) + y = torch.randn(m,n, device=device, dtype=torch.bfloat16) + y_8 = torch.ones (m,n, device=device, dtype=torch.float8_e4m3fn) + w_8 = torch.ones (n,k, device=device, dtype=torch.float8_e4m3fn) + w_scale = torch.ones(1, device=device, dtype=torch.bfloat16) + one_scale = torch.ones(1, device=device) + + run_benchmark(lambda: torch._scaled_mm(x_8, w_8.T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()), "torch mm fp8") + run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), "triton mm fp8") + def torch_backward(a, b): + torch._scaled_mm(a, b.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()) + run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward fp8") + run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward fp8") + run_benchmark(lambda: fp8_forward_tokenwise(x, w_8, w_scale), "torch forward fp8", compile=True) + run_benchmark(lambda: fp8_backward_axiswise(y, w_8, w_scale), "triton backward fp8", compile=True) + + +if __name__ == "__main__": + benchmark_int8(2 * 1024 + 50, 3072, 3072 + 16) + benchmark_fp8(2 * 1024 + 50, 3072, 3072 + 16) diff --git a/modules/module/quantized/mixin/QuantizedModuleMixin.py b/modules/module/quantized/mixin/QuantizedModuleMixin.py index 54904f59e..c110fd7ef 100644 --- a/modules/module/quantized/mixin/QuantizedModuleMixin.py +++ b/modules/module/quantized/mixin/QuantizedModuleMixin.py @@ -5,5 +5,5 @@ class QuantizedModuleMixin(metaclass=ABCMeta): @abstractmethod - def quantize(self, device: torch.device | None = None): + def quantize(self, device: torch.device | None = None, **kwargs): pass diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index f10b5e99d..143287d67 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -29,6 +29,7 @@ from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.TimeUnit import TimeUnit from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.ModuleFilter import ModuleFilter from modules.util.profiling_util import TorchMemoryRecorder, TorchProfiler from modules.util.time_util import get_string_timestamp from modules.util.torch_util import torch_gc @@ -122,10 +123,15 @@ def start(self): self.callbacks.on_update_status("loading the model") + quant_filters = [ + ModuleFilter(pattern, use_regex=self.config.quantization_layer_filter_regex) + for pattern in self.config.quantization_layer_filter.split(",") + ] self.model = self.model_loader.load( model_type=self.config.model_type, model_names=model_names, weight_dtypes=self.config.weight_dtypes(), + quant_filters=quant_filters, ) self.model.train_config = self.config diff --git a/modules/ui/ConvertModelUI.py b/modules/ui/ConvertModelUI.py index 685aeaa85..837b27614 100644 --- a/modules/ui/ConvertModelUI.py +++ b/modules/ui/ConvertModelUI.py @@ -131,6 +131,7 @@ def 
convert_model(self): base_model=self.convert_model_args.input_name, ), weight_dtypes=self.convert_model_args.weight_dtypes(), + #TODO quantization layer filter ) elif self.convert_model_args.training_method in [TrainingMethod.LORA, TrainingMethod.EMBEDDING]: model = model_loader.load( @@ -140,6 +141,7 @@ def convert_model(self): embedding=EmbeddingName(str(uuid4()), self.convert_model_args.input_name), ), weight_dtypes=self.convert_model_args.weight_dtypes(), + #TODO quantization layer filter ) else: raise Exception("could not load model: " + self.convert_model_args.input_name) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 24c6b2c7f..d0cbf515b 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -1,5 +1,16 @@ from pathlib import Path +from modules.modelSetup.BaseChromaSetup import PRESETS as chroma_presets +from modules.modelSetup.BaseFluxSetup import PRESETS as flux_presets +from modules.modelSetup.BaseHiDreamSetup import PRESETS as hidream_presets +from modules.modelSetup.BaseHunyuanVideoSetup import PRESETS as hunyuan_video_presets +from modules.modelSetup.BasePixArtAlphaSetup import PRESETS as pixart_presets +from modules.modelSetup.BaseQwenSetup import PRESETS as qwen_presets +from modules.modelSetup.BaseSanaSetup import PRESETS as sana_presets +from modules.modelSetup.BaseStableDiffusion3Setup import PRESETS as sd3_presets +from modules.modelSetup.BaseStableDiffusionSetup import PRESETS as sd_presets +from modules.modelSetup.BaseStableDiffusionXLSetup import PRESETS as sdxl_presets +from modules.modelSetup.BaseWuerstchenSetup import PRESETS as sc_presets from modules.util.config.TrainConfig import TrainConfig from modules.util.enum.ConfigPart import ConfigPart from modules.util.enum.DataType import DataType @@ -33,46 +44,52 @@ def refresh_ui(self): self.scroll_frame = ctk.CTkScrollableFrame(self.master, fg_color="transparent") self.scroll_frame.grid(row=0, column=0, sticky="nsew") + self.scroll_frame.grid_columnconfigure(0, weight=1) - self.scroll_frame.grid_columnconfigure(0, weight=0) - self.scroll_frame.grid_columnconfigure(1, weight=10) - self.scroll_frame.grid_columnconfigure(2, minsize=50) - self.scroll_frame.grid_columnconfigure(3, weight=0) - self.scroll_frame.grid_columnconfigure(4, weight=1) + base_frame = ctk.CTkFrame(master=self.scroll_frame, corner_radius=5) + base_frame.grid(row=0, column=0, padx=5, pady=5, sticky="nsew") + + base_frame.grid_columnconfigure(0, weight=0) + base_frame.grid_columnconfigure(1, weight=10)#, minsize=500) + base_frame.grid_columnconfigure(2, minsize=50) + base_frame.grid_columnconfigure(3, weight=0) + base_frame.grid_columnconfigure(4, weight=1) if self.train_config.model_type.is_stable_diffusion(): #TODO simplify - self.__setup_stable_diffusion_ui() + self.__setup_stable_diffusion_ui(base_frame) if self.train_config.model_type.is_stable_diffusion_3(): - self.__setup_stable_diffusion_3_ui() + self.__setup_stable_diffusion_3_ui(base_frame) elif self.train_config.model_type.is_stable_diffusion_xl(): - self.__setup_stable_diffusion_xl_ui() + self.__setup_stable_diffusion_xl_ui(base_frame) elif self.train_config.model_type.is_wuerstchen(): - self.__setup_wuerstchen_ui() + self.__setup_wuerstchen_ui(base_frame) elif self.train_config.model_type.is_pixart(): - self.__setup_pixart_alpha_ui() + self.__setup_pixart_alpha_ui(base_frame) elif self.train_config.model_type.is_flux(): - self.__setup_flux_ui() + self.__setup_flux_ui(base_frame) elif self.train_config.model_type.is_chroma(): - self.__setup_chroma_ui() + 
self.__setup_chroma_ui(base_frame) elif self.train_config.model_type.is_qwen(): - self.__setup_qwen_ui() + self.__setup_qwen_ui(base_frame) elif self.train_config.model_type.is_sana(): - self.__setup_sana_ui() + self.__setup_sana_ui(base_frame) elif self.train_config.model_type.is_hunyuan_video(): - self.__setup_hunyuan_video_ui() + self.__setup_hunyuan_video_ui(base_frame) elif self.train_config.model_type.is_hi_dream(): - self.__setup_hi_dream_ui() + self.__setup_hi_dream_ui(base_frame) - def __setup_stable_diffusion_ui(self): + def __setup_stable_diffusion_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_unet=True, has_text_encoder=True, has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method in [ @@ -82,10 +99,11 @@ def __setup_stable_diffusion_ui(self): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_stable_diffusion_3_ui(self): + def __setup_stable_diffusion_3_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder_1=True, @@ -94,16 +112,18 @@ def __setup_stable_diffusion_3_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_flux_ui(self): + def __setup_flux_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, allow_override_transformer=True, @@ -112,16 +132,18 @@ def __setup_flux_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_chroma_ui(self): + def __setup_chroma_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, allow_override_transformer=True, @@ -129,16 +151,18 @@ def __setup_chroma_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_qwen_ui(self): + def __setup_qwen_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, allow_override_transformer=True, @@ -146,16 +170,18 @@ def __setup_qwen_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_stable_diffusion_xl_ui(self): + def 
__setup_stable_diffusion_xl_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_unet=True, has_text_encoder_1=True, @@ -163,24 +189,27 @@ def __setup_stable_diffusion_xl_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_wuerstchen_ui(self): + def __setup_wuerstchen_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_prior=True, allow_override_prior=self.train_config.model_type.is_stable_cascade(), has_text_encoder=True, ) - row = self.__create_effnet_encoder_components(row) - row = self.__create_decoder_components(row, self.train_config.model_type.is_wuerstchen_v2()) + row = self.__create_effnet_encoder_components(frame, row) + row = self.__create_decoder_components(frame, row, self.train_config.model_type.is_wuerstchen_v2()) row = self.__create_output_components( + frame, row, allow_safetensors=self.train_config.training_method != TrainingMethod.FINE_TUNE or self.train_config.model_type.is_stable_cascade(), @@ -188,42 +217,47 @@ def __setup_wuerstchen_ui(self): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_pixart_alpha_ui(self): + def __setup_pixart_alpha_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder=True, has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_sana_ui(self): + def __setup_sana_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder=True, has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=self.train_config.training_method != TrainingMethod.FINE_TUNE, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_hunyuan_video_ui(self): + def __setup_hunyuan_video_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, has_transformer=True, has_text_encoder_1=True, @@ -231,16 +265,18 @@ def __setup_hunyuan_video_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __setup_hi_dream_ui(self): + def __setup_hi_dream_ui(self, frame): row = 0 - row = self.__create_base_dtype_components(row) + row = self.__create_base_dtype_components(frame, row) row = self.__create_base_components( + frame, row, 
has_transformer=True, has_text_encoder_1=True, @@ -251,52 +287,69 @@ def __setup_hi_dream_ui(self): has_vae=True, ) row = self.__create_output_components( + frame, row, allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) - def __create_dtype_options(self, include_none:bool=True, include_gguf=False) -> list[tuple[str, DataType]]: + def __create_dtype_options(self, include_none: bool=True, include_gguf: bool=False, include_svd: bool=False) -> list[tuple[str, DataType]]: options = [ ("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16), ("float16", DataType.FLOAT_16), - ("float8", DataType.FLOAT_8), + ("float8 (W8)", DataType.FLOAT_8), + ("float W8A8", DataType.FLOAT_W8A8), + ("int W8A8", DataType.INT_W8A8), # ("int8", DataType.INT_8), # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332 + ("bfloat16 A8 int", DataType.BFLOAT_16_A8_INT), + ("bfloat16 A8 float", DataType.BFLOAT_16_A8_FLOAT), + ("float16 A8 int", DataType.FLOAT_16_A8_INT), + ("float16 A8 float", DataType.FLOAT_16_A8_FLOAT), ("nfloat4", DataType.NFLOAT_4), ] + if include_svd: + options += [ + ("float8 (W8) SVDQuant", DataType.FLOAT_8_SVD), + ("float W8A8 SVDQuant", DataType.FLOAT_W8A8_SVD), + ("int W8A8 SVDQuant", DataType.INT_W8A8_SVD), + ("nfloat4 SVD", DataType.NFLOAT_4_SVD), + ] + if include_gguf: options.append(("GGUF", DataType.GGUF)) + options.append(("GGUF A8 float", DataType.GGUF_A8_FLOAT)) + options.append(("GGUF A8 int", DataType.GGUF_A8_INT)) if include_none: options.insert(0, ("", DataType.NONE)) return options - def __create_base_dtype_components(self, row: int) -> int: + def __create_base_dtype_components(self, frame, row: int) -> int: # huggingface token - components.label(self.scroll_frame, row, 0, "Hugging Face Token", + components.label(frame, row, 0, "Hugging Face Token", tooltip="Enter your Hugging Face access token if you have used a protected Hugging Face repository below.\nThis value is stored separately, not saved to your configuration file. " "Go to https://huggingface.co/settings/tokens to create an access token.", wide_tooltip=True) - components.entry(self.scroll_frame, row, 1, self.ui_state, "secrets.huggingface_token") + components.entry(frame, row, 1, self.ui_state, "secrets.huggingface_token") row += 1 # base model - components.label(self.scroll_frame, row, 0, "Base Model", + components.label(frame, row, 0, "Base Model", tooltip="Filename, directory or Hugging Face repository of the base model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "base_model_name", + frame, row, 1, self.ui_state, "base_model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # weight dtype - components.label(self.scroll_frame, row, 3, "Weight Data Type", + components.label(frame, row, 3, "Weight Data Type", tooltip="The base model weight data type used for training. 
This can reduce memory consumption, but reduces precision") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(False), + components.options_kv(frame, row, 4, self.__create_dtype_options(False), self.ui_state, "weight_dtype") row += 1 @@ -305,6 +358,7 @@ def __create_base_dtype_components(self, row: int) -> int: def __create_base_components( self, + frame, row: int, has_unet: bool = False, has_prior: bool = False, @@ -321,9 +375,9 @@ def __create_base_components( ) -> int: if has_unet: # unet weight dtype - components.label(self.scroll_frame, row, 3, "Override UNet Data Type", + components.label(frame, row, 3, "Override UNet Data Type", tooltip="Overrides the unet weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(include_svd=True), self.ui_state, "unet.weight_dtype") row += 1 @@ -331,17 +385,17 @@ def __create_base_components( if has_prior: if allow_override_prior: # prior model - components.label(self.scroll_frame, row, 0, "Prior Model", + components.label(frame, row, 0, "Prior Model", tooltip="Filename, directory or Hugging Face repository of the prior model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "prior.model_name", + frame, row, 1, self.ui_state, "prior.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # prior weight dtype - components.label(self.scroll_frame, row, 3, "Override Prior Data Type", + components.label(frame, row, 3, "Override Prior Data Type", tooltip="Overrides the prior weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(include_svd=True), self.ui_state, "prior.weight_dtype") row += 1 @@ -349,60 +403,111 @@ def __create_base_components( if has_transformer: if allow_override_transformer: # transformer model - components.label(self.scroll_frame, row, 0, "Override Transformer / GGUF", + components.label(frame, row, 0, "Override Transformer / GGUF", tooltip="Can be used to override the transformer in the base model. Safetensors and GGUF files are supported, local and on Huggingface. 
If a GGUF file is used, the DataType must also be set to GGUF") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "transformer.model_name", + frame, row, 1, self.ui_state, "transformer.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # transformer weight dtype - components.label(self.scroll_frame, row, 3, "Override Transformer Data Type", + components.label(frame, row, 3, "Override Transformer Data Type", tooltip="Overrides the transformer weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(include_gguf=True), + components.options_kv(frame, row, 4, self.__create_dtype_options(include_svd=True, include_gguf=True), self.ui_state, "transformer.weight_dtype") row += 1 + presets = [] + if self.train_config.model_type.is_stable_diffusion(): #TODO simplify and de-duplicate with layer filter on training tab + presets = sd_presets + elif self.train_config.model_type.is_stable_diffusion_xl(): + presets = sdxl_presets + elif self.train_config.model_type.is_stable_diffusion_3(): + presets = sd3_presets + elif self.train_config.model_type.is_wuerstchen(): + presets = sc_presets + elif self.train_config.model_type.is_pixart(): + presets = pixart_presets + elif self.train_config.model_type.is_flux(): + presets = flux_presets + elif self.train_config.model_type.is_qwen(): + presets = qwen_presets + elif self.train_config.model_type.is_chroma(): + presets = chroma_presets + elif self.train_config.model_type.is_sana(): + presets = sana_presets + elif self.train_config.model_type.is_hunyuan_video(): + presets = hunyuan_video_presets + elif self.train_config.model_type.is_hi_dream(): + presets = hidream_presets + else: + presets = {"full": []} + + components.label(frame, row, 0, "Quantization") + components.layer_filter_entry(frame, row, 1, self.ui_state, + preset_var_name="quantization_layer_filter_preset", presets=presets, + preset_label="Layer Filter", + preset_tooltip="Select a preset defining which layers to quantize. Quantization of certain layers can decrease model quality. Only applies to the transformer/unet", + entry_var_name="quantization_layer_filter", + entry_tooltip="Comma-separated list of layers to quantize. Regular expressions (if toggled) are supported. Any model layer with a matching name will be quantized", + regex_var_name="quantization_layer_filter_regex", + regex_tooltip="If enabled, layer filter patterns are interpreted as regular expressions. Otherwise, simple substring matching is used.", + frame_color="transparent", + ) + # compile - components.label(self.scroll_frame, row, 3, "Compile transformer blocks", + components.label(frame, row, 3, "Compile transformer blocks", tooltip="Uses torch.compile and Triton to significantly speed up training. Only applies to transformer/unet. 
Disable in case of compatibility issues.") - components.switch(self.scroll_frame, row, 4, self.ui_state, "compile") + components.switch(frame, row, 4, self.ui_state, "compile") row += 1 + # SVDQuant + components.label(frame, row, 3, "SVDQuant Data Type", + tooltip="What datatype to use for SVDQuant weights decomposition.") + components.options_kv(frame, row, 4, [("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16)], + self.ui_state, "svd_dtype") + row += 1 + + components.label(frame, row, 3, "SVDQuant Rank", + tooltip="Rank for SVDQuant weights decomposition") + components.entry(frame, row, 4, self.ui_state, "svd_rank") + row += 1 + + if has_text_encoder: # text encoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder Data Type", + components.label(frame, row, 3, "Override Text Encoder Data Type", tooltip="Overrides the text encoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder.weight_dtype") row += 1 if has_text_encoder_1: # text encoder 1 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 1 Data Type", + components.label(frame, row, 3, "Override Text Encoder 1 Data Type", tooltip="Overrides the text encoder 1 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder.weight_dtype") row += 1 if has_text_encoder_2: # text encoder 2 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 2 Data Type", + components.label(frame, row, 3, "Override Text Encoder 2 Data Type", tooltip="Overrides the text encoder 2 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder_2.weight_dtype") row += 1 if has_text_encoder_3: # text encoder 3 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 3 Data Type", + components.label(frame, row, 3, "Override Text Encoder 3 Data Type", tooltip="Overrides the text encoder 3 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "text_encoder_3.weight_dtype") row += 1 @@ -410,53 +515,53 @@ def __create_base_components( if has_text_encoder_4: if allow_override_text_encoder_4: # text encoder 4 weight dtype - components.label(self.scroll_frame, row, 0, "Text Encoder 4 Override", + components.label(frame, row, 0, "Text Encoder 4 Override", tooltip="Filename, directory or Hugging Face repository of the text encoder 4 model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "text_encoder_4.model_name", + frame, row, 1, self.ui_state, "text_encoder_4.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # text encoder 4 weight dtype - components.label(self.scroll_frame, row, 3, "Override Text Encoder 4 Data Type", + components.label(frame, row, 3, "Override Text Encoder 4 Data Type", tooltip="Overrides the text encoder 4 weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, 
"text_encoder_4.weight_dtype") row += 1 if has_vae: # base model - components.label(self.scroll_frame, row, 0, "VAE Override", + components.label(frame, row, 0, "VAE Override", tooltip="Directory or Hugging Face repository of a VAE model in diffusers format. Can be used to override the VAE included in the base model. Using a safetensor VAE file will cause an error that the model cannot be loaded.") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "vae.model_name", + frame, row, 1, self.ui_state, "vae.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # vae weight dtype - components.label(self.scroll_frame, row, 3, "Override VAE Data Type", + components.label(frame, row, 3, "Override VAE Data Type", tooltip="Overrides the vae weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "vae.weight_dtype") row += 1 return row - def __create_effnet_encoder_components(self, row: int): + def __create_effnet_encoder_components(self, frame, row: int): # effnet encoder model - components.label(self.scroll_frame, row, 0, "Effnet Encoder Model", + components.label(frame, row, 0, "Effnet Encoder Model", tooltip="Filename, directory or Hugging Face repository of the effnet encoder model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "effnet_encoder.model_name", + frame, row, 1, self.ui_state, "effnet_encoder.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # effnet encoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Effnet Encoder Data Type", + components.label(frame, row, 3, "Override Effnet Encoder Data Type", tooltip="Overrides the effnet encoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "effnet_encoder.weight_dtype") row += 1 @@ -465,38 +570,39 @@ def __create_effnet_encoder_components(self, row: int): def __create_decoder_components( self, + frame, row: int, has_text_encoder: bool, ) -> int: # decoder model - components.label(self.scroll_frame, row, 0, "Decoder Model", + components.label(frame, row, 0, "Decoder Model", tooltip="Filename, directory or Hugging Face repository of the decoder model") components.file_entry( - self.scroll_frame, row, 1, self.ui_state, "decoder.model_name", + frame, row, 1, self.ui_state, "decoder.model_name", path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x ) # decoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Decoder Data Type", + components.label(frame, row, 3, "Override Decoder Data Type", tooltip="Overrides the decoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "decoder.weight_dtype") row += 1 if has_text_encoder: # decoder text encoder weight dtype - components.label(self.scroll_frame, row, 3, "Override Decoder Text Encoder Data Type", + components.label(frame, row, 3, "Override Decoder Text Encoder Data Type", tooltip="Overrides the decoder text encoder weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, 
"decoder_text_encoder.weight_dtype") row += 1 # decoder vqgan weight dtype - components.label(self.scroll_frame, row, 3, "Override Decoder VQGAN Data Type", + components.label(frame, row, 3, "Override Decoder VQGAN Data Type", tooltip="Overrides the decoder vqgan weight data type") - components.options_kv(self.scroll_frame, row, 4, self.__create_dtype_options(), + components.options_kv(frame, row, 4, self.__create_dtype_options(), self.ui_state, "decoder_vqgan.weight_dtype") row += 1 @@ -505,20 +611,21 @@ def __create_decoder_components( def __create_output_components( self, + frame, row: int, allow_safetensors: bool = False, allow_diffusers: bool = False, allow_legacy_safetensors: bool = False, ) -> int: # output model destination - components.label(self.scroll_frame, row, 0, "Model Output Destination", + components.label(frame, row, 0, "Model Output Destination", tooltip="Filename or directory where the output model is saved") - components.file_entry(self.scroll_frame, row, 1, self.ui_state, "output_model_destination", is_output=True) + components.file_entry(frame, row, 1, self.ui_state, "output_model_destination", is_output=True) # output data type - components.label(self.scroll_frame, row, 3, "Output Data Type", + components.label(frame, row, 3, "Output Data Type", tooltip="Precision to use when saving the output model") - components.options_kv(self.scroll_frame, row, 4, [ + components.options_kv(frame, row, 4, [ ("float16", DataType.FLOAT_16), ("float32", DataType.FLOAT_32), ("bfloat16", DataType.BFLOAT_16), @@ -537,17 +644,17 @@ def __create_output_components( # if allow_legacy_safetensors: # formats.append(("Legacy Safetensors", ModelFormat.LEGACY_SAFETENSORS)) - components.label(self.scroll_frame, row, 0, "Output Format", + components.label(frame, row, 0, "Output Format", tooltip="Format to use when saving the output model") - components.options_kv(self.scroll_frame, row, 1, formats, self.ui_state, "output_model_format") + components.options_kv(frame, row, 1, formats, self.ui_state, "output_model_format") # include config - components.label(self.scroll_frame, row, 3, "Include Config", + components.label(frame, row, 3, "Include Config", tooltip="Include the training configuration in the final model. Only supported for safetensors files. " "None: No config is included. " "Settings: All training settings are included. 
" "All: All settings, including the samples and concepts are included.") - components.options_kv(self.scroll_frame, row, 4, [ + components.options_kv(frame, row, 4, [ ("None", ConfigPart.NONE), ("Settings", ConfigPart.SETTINGS), ("All", ConfigPart.ALL), diff --git a/modules/ui/SampleWindow.py b/modules/ui/SampleWindow.py index a93b3593f..15a7dfee2 100644 --- a/modules/ui/SampleWindow.py +++ b/modules/ui/SampleWindow.py @@ -18,6 +18,7 @@ from modules.util.enum.EMAMode import EMAMode from modules.util.enum.FileType import FileType from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.ModuleFilter import ModuleFilter from modules.util.time_util import get_string_timestamp from modules.util.ui import components from modules.util.ui.ui_utils import set_window_icon @@ -124,10 +125,15 @@ def __load_model(self) -> BaseModel: else: print("No backup found, loading without backup...") + quant_filters = [ + ModuleFilter(pattern, use_regex=self.initial_train_config.quantization_layer_filter_regex) + for pattern in self.initial_train_config.quantization_layer_filter.split(",") + ] model = model_loader.load( model_type=self.initial_train_config.model_type, model_names=model_names, weight_dtypes=self.initial_train_config.weight_dtypes(), + quant_filters=quant_filters, ) model.train_config = self.initial_train_config diff --git a/modules/ui/TrainUI.py b/modules/ui/TrainUI.py index 5e7141488..584b68f81 100644 --- a/modules/ui/TrainUI.py +++ b/modules/ui/TrainUI.py @@ -750,6 +750,7 @@ def __training_thread_function(self): trainer.start() if self.train_config.cloud.enabled: self.ui_state.get_var("secrets.cloud").update(self.train_config.secrets.cloud) + self.start_time = time.monotonic() trainer.train() except Exception: diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index cb769227c..cb67ba1f9 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -42,22 +42,6 @@ def __init__(self, master, train_config: TrainConfig, ui_state: UIState): master.grid_rowconfigure(0, weight=1) master.grid_columnconfigure(0, weight=1) - #layer filter: - self.layer_entry = None - self.layer_entry_fg_color = None - self.layer_entry_text_color = None - self.layer_selector = None - self.regex_label = None - self.regex_switch = None - self.presets = {} - self.presets_list = [] - self.prior_custom = "" - self.prior_selected = None - self.layer_filter_trace_id = self.ui_state.add_var_trace( - "layer_filter_preset", - self.__on_layer_filter_preset_change, - ) - self.scroll_frame = None self.refresh_ui() @@ -85,32 +69,6 @@ def refresh_ui(self): column_2.grid(row=0, column=2, sticky="nsew") column_2.grid_columnconfigure(0, weight=1) - if self.train_config.model_type.is_stable_diffusion(): #TODO simplify - self.presets = sd_presets - elif self.train_config.model_type.is_stable_diffusion_xl(): - self.presets = sdxl_presets - elif self.train_config.model_type.is_stable_diffusion_3(): - self.presets = sd3_presets - elif self.train_config.model_type.is_wuerstchen(): - self.presets = sc_presets - elif self.train_config.model_type.is_pixart(): - self.presets = pixart_presets - elif self.train_config.model_type.is_flux(): - self.presets = flux_presets - elif self.train_config.model_type.is_qwen(): - self.presets = qwen_presets - elif self.train_config.model_type.is_chroma(): - self.presets = chroma_presets - elif self.train_config.model_type.is_sana(): - self.presets = sana_presets - elif self.train_config.model_type.is_hunyuan_video(): - self.presets = hunyuan_video_presets - elif 
self.train_config.model_type.is_hi_dream(): - self.presets = hidream_presets - else: - self.presets = {"full": []} - self.presets_list = list(self.presets.keys()) + ["custom"] - if self.train_config.model_type.is_stable_diffusion(): self.__setup_stable_diffusion_ui(column_0, column_1, column_2) if self.train_config.model_type.is_stable_diffusion_3(): @@ -792,91 +750,41 @@ def __create_loss_frame(self, master, row, supports_vb_loss: bool = False): row += 1 def __create_layer_frame(self, master, row): - frame = ctk.CTkFrame(master=master, corner_radius=5) - frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") - frame.grid_columnconfigure(0, weight=1) - - components.label(frame, 0, 0, "Layer Filter", - tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own.\nA blank 'custom' field or 'Full' will train all layers.") - self.layer_selector = components.options( - frame, 0, 1, self.presets_list, self.ui_state, "layer_filter_preset", - command=self.__preset_set_layer_choice - ) - - self.layer_entry = components.entry( - frame, 1, 0, self.ui_state, "layer_filter", - tooltip="Comma-separated list of diffusion layers to train. Regular expressions (if toggled) are supported. Any model layer with a matching name will be trained" - ) - self.layer_entry_fg_color = self.layer_entry.cget("fg_color") - self.layer_entry_text_color = self.layer_entry.cget("text_color") - - self.regex_label = components.label( - frame, 2, 0, "Use Regex", - tooltip="If enabled, layer filter patterns are interpreted as regular expressions. Otherwise, simple substring matching is used." - ) - self.regex_switch = components.switch( - frame, 2, 1, self.ui_state, "layer_filter_regex" + presets = [] + if self.train_config.model_type.is_stable_diffusion(): #TODO simplify + presets = sd_presets + elif self.train_config.model_type.is_stable_diffusion_xl(): + presets = sdxl_presets + elif self.train_config.model_type.is_stable_diffusion_3(): + presets = sd3_presets + elif self.train_config.model_type.is_wuerstchen(): + presets = sc_presets + elif self.train_config.model_type.is_pixart(): + presets = pixart_presets + elif self.train_config.model_type.is_flux(): + presets = flux_presets + elif self.train_config.model_type.is_qwen(): + presets = qwen_presets + elif self.train_config.model_type.is_chroma(): + presets = chroma_presets + elif self.train_config.model_type.is_sana(): + presets = sana_presets + elif self.train_config.model_type.is_hunyuan_video(): + presets = hunyuan_video_presets + elif self.train_config.model_type.is_hi_dream(): + presets = hidream_presets + else: + presets = {"full": []} + components.layer_filter_entry(master, row, 0, self.ui_state, + preset_var_name="layer_filter_preset", presets=presets, + preset_label="Layer Filter", + preset_tooltip="Select a preset defining which layers to train, or select 'Custom' to define your own.\nA blank 'custom' field or 'Full' will train all layers.", + entry_var_name="layer_filter", + entry_tooltip="Comma-separated list of diffusion layers to train. Regular expressions (if toggled) are supported. Any model layer with a matching name will be trained", + regex_var_name="layer_filter_regex", + regex_tooltip="If enabled, layer filter patterns are interpreted as regular expressions. 
Otherwise, simple substring matching is used.", ) - # Let the user set their own layer filter - if self.train_config.layer_filter and self.train_config.layer_filter_preset == "custom": - self.prior_custom = self.train_config.layer_filter - else: - self.prior_custom = "" - - self.layer_entry.grid_configure(columnspan=2, sticky="ew") - # Some configs will come with the layer_filter_preset unset or wrong for - # the new model, so let's set it now to a reasonable default so it hits - # the UI correctly. - if self.layer_selector.get() not in self.presets_list: - self.layer_selector.set(self.presets_list[0]) - self.__preset_set_layer_choice(self.layer_selector.get()) - - - def __preset_set_layer_choice(self, selected: str): - if not selected: - selected = self.presets_list[0] - - if selected == "custom": - # Restore prior custom text and allow editing + regex toggle - self.__show_layer_entry() - self.layer_entry.configure(state="normal", fg_color=self.layer_entry_fg_color, text_color=self.layer_entry_text_color) - self.layer_entry.cget('textvariable').set(self.prior_custom) - self.regex_label.grid() - self.regex_switch.grid() - else: - # Preserve custom text before overwriting - if self.prior_selected == "custom": - self.prior_custom = self.layer_entry.get() - - # Resolve preset definition (list[str] OR {'patterns': [...], 'regex': bool}) - preset_def = self.presets.get(selected, []) - if isinstance(preset_def, dict): - patterns = preset_def.get("patterns", []) - preset_uses_regex = bool(preset_def.get("regex", False)) - else: - patterns = preset_def - preset_uses_regex = False - - disabled_color = ("gray85", "gray17") - disabled_text_color = ("gray30", "gray70") - self.layer_entry.configure(state="disabled", fg_color=disabled_color, text_color=disabled_text_color) - self.layer_entry.cget('textvariable').set(",".join(patterns)) - - self.train_config.layer_filter = ",".join(patterns) - - self.train_config.layer_filter_regex_regex = preset_uses_regex - self.ui_state.get_var("layer_filter_regex").set(preset_uses_regex) - - self.regex_label.grid_remove() - self.regex_switch.grid_remove() - - if selected == "full" and not patterns: - self.__hide_layer_entry() - else: - self.__show_layer_entry() - - self.prior_selected = selected def __on_layer_filter_preset_change(self): if not self.layer_selector: diff --git a/modules/util/LayerOffloadConductor.py b/modules/util/LayerOffloadConductor.py index 79229271d..70511659d 100644 --- a/modules/util/LayerOffloadConductor.py +++ b/modules/util/LayerOffloadConductor.py @@ -793,7 +793,7 @@ def __module_to_device_except_layers( sub_module_parameters = set(sum([list(x.parameters()) for x in self.__layers], [])) def convert(t): - if t in sub_module_parameters: + if t in sub_module_parameters or t.is_meta: return t return t.to(device=device) diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index 4c7d0f7e5..362ec36c0 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -332,6 +332,8 @@ class TrainConfig(BaseConfig): layer_offload_fraction: float force_circular_padding: bool compile: bool + svd_dtype: DataType + svd_rank: int # data settings concept_file_name: str @@ -409,6 +411,11 @@ class TrainConfig(BaseConfig): # transformer transformer: TrainModelPartConfig + # quantization + quantization_layer_filter: str + quantization_layer_filter_preset: str + quantization_layer_filter_regex: bool + # text encoder text_encoder: TrainModelPartConfig text_encoder_layer_skip: int @@ -893,7 +900,9 @@ def 
default_values() -> 'TrainConfig': data.append(("enable_activation_offloading", True, bool, False)) data.append(("layer_offload_fraction", 0.0, float, False)) data.append(("force_circular_padding", False, bool, False)) - data.append(("compile", False, bool, False)) + data.append(("compile", True, bool, False)) + data.append(("svd_dtype", DataType.BFLOAT_16, DataType, False)) + data.append(("svd_rank", 128, int, False)) # data settings data.append(("concept_file_name", "training_concepts/concepts.json", str, False)) @@ -971,7 +980,7 @@ def default_values() -> 'TrainConfig': prior.weight_dtype = DataType.NONE data.append(("prior", prior, TrainModelPartConfig, False)) - # prior + # transformer transformer = TrainModelPartConfig.default_values() transformer.model_name = "" transformer.train = True @@ -980,6 +989,11 @@ def default_values() -> 'TrainConfig': transformer.weight_dtype = DataType.NONE data.append(("transformer", transformer, TrainModelPartConfig, False)) + #quantization layer filter + data.append(("quantization_layer_filter", "", str, False)) + data.append(("quantization_layer_filter_preset", "full", str, False)) + data.append(("quantization_layer_filter_regex", False, bool, False)) + # text encoder text_encoder = TrainModelPartConfig.default_values() text_encoder.train = True diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py index dd15e01d8..43cee577e 100644 --- a/modules/util/enum/DataType.py +++ b/modules/util/enum/DataType.py @@ -12,7 +12,19 @@ class DataType(Enum): TFLOAT_32 = 'TFLOAT_32' INT_8 = 'INT_8' NFLOAT_4 = 'NFLOAT_4' + FLOAT_W8A8 = 'FLOAT_W8A8' + INT_W8A8 = 'INT_W8A8' + FLOAT_8_SVD = 'FLOAT_8_SVD' + NFLOAT_4_SVD = 'NFLOAT_4_SVD' + FLOAT_W8A8_SVD = 'FLOAT_W8A8_SVD' + INT_W8A8_SVD = 'INT_W8A8_SVD' GGUF = 'GGUF' + GGUF_A8_FLOAT = 'GGUF_A8_FLOAT' + GGUF_A8_INT = 'GGUF_A8_INT' + FLOAT_16_A8_FLOAT = 'FLOAT_16_A8_FLOAT' + FLOAT_16_A8_INT = 'FLOAT_16_A8_INT' + BFLOAT_16_A8_FLOAT = 'BFLOAT_16_A8_FLOAT' + BFLOAT_16_A8_INT = 'BFLOAT_16_A8_INT' def __str__(self): return self.value @@ -33,6 +45,14 @@ def torch_dtype( return torch.bfloat16 case DataType.TFLOAT_32: return torch.float32 + case DataType.FLOAT_16_A8_FLOAT: + return torch.float16 + case DataType.FLOAT_16_A8_INT: + return torch.float16 + case DataType.BFLOAT_16_A8_FLOAT: + return torch.bfloat16 + case DataType.BFLOAT_16_A8_INT: + return torch.bfloat16 case _: return None @@ -42,13 +62,36 @@ def enable_tf(self): def is_quantized(self): return self in [DataType.FLOAT_8, DataType.INT_8, - DataType.NFLOAT_4] + DataType.FLOAT_W8A8, + DataType.INT_W8A8, + DataType.NFLOAT_4, + DataType.FLOAT_8_SVD, + DataType.FLOAT_W8A8_SVD, + DataType.INT_W8A8_SVD, + DataType.NFLOAT_4_SVD] + + def is_gguf(self): + return self in [DataType.GGUF, + DataType.GGUF_A8_FLOAT, + DataType.GGUF_A8_INT] def quantize_fp8(self): - return self == DataType.FLOAT_8 + return self == DataType.FLOAT_8 or self == DataType.FLOAT_8_SVD def quantize_int8(self): return self == DataType.INT_8 + def quantize_fpW8A8(self): + return self == DataType.FLOAT_W8A8 or self == DataType.FLOAT_W8A8_SVD + + def quantize_intW8A8(self): + return self == DataType.INT_W8A8 or self == DataType.INT_W8A8_SVD + def quantize_nf4(self): - return self == DataType.NFLOAT_4 + return self == DataType.NFLOAT_4 or self == DataType.NFLOAT_4_SVD + + def quantize_svd(self): + return self in [DataType.FLOAT_8_SVD, + DataType.NFLOAT_4_SVD, + DataType.FLOAT_W8A8_SVD, + DataType.INT_W8A8_SVD] diff --git a/modules/util/quantization_util.py b/modules/util/quantization_util.py 
index 3a3200e23..a89034d93 100644 --- a/modules/util/quantization_util.py +++ b/modules/util/quantization_util.py @@ -1,15 +1,20 @@ +import os from collections.abc import Callable +from functools import partial -from modules.module.quantized.LinearFp8 import LinearFp8 from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin +from modules.util.config.TrainConfig import TrainConfig from modules.util.enum.DataType import DataType +from modules.util.ModuleFilter import ModuleFilter import torch from torch import Tensor, nn from diffusers.quantizers.gguf.utils import GGUFLinear, dequantize_gguf_tensor +from tqdm import tqdm + try: from modules.module.quantized.LinearNf4 import LinearNf4 @@ -18,67 +23,94 @@ bnb = None LinearNf4 = None +def quantize_int8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).round_().clamp_(-128.0, 127.0).to(torch.int8) + return q -def __create_nf4_linear_layer(module: nn.Linear, copy_parameters: bool) -> nn.Module: - bias = module.bias is not None +def quantize_int8_tensorwise_get_scale(x: Tensor) -> float: + abs_max = x.abs().max() + scale = (abs_max.float() / 127.0).clamp(min=1e-30) + return scale - quant_linear = LinearNf4( - in_features=module.in_features, - out_features=module.out_features, - bias=bias, - ) +def quantize_int8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_int8_tensorwise_get_scale(x) + q = quantize_int8(x, scale) + return q, scale - if copy_parameters: - quant_linear.weight.data = module.weight.data - if bias: - quant_linear.bias.data = module.bias.data +def quantize_int8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 127.0).clamp(min=1e-30) + return scale - return quant_linear +def quantize_int8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_int8_axiswise_get_scale(x, dim) + q = quantize_int8(x, scale) + return q, scale +def quantize_fp8(x: Tensor, scale: float | Tensor) -> Tensor: + q = x.float().mul(1.0 / scale).clamp_(-448.0, 448.0).to(torch.float8_e4m3fn) + return q -def __create_int8_linear_layer(module: nn.Linear, copy_parameters: bool) -> nn.Module: - bias = module.bias is not None +def quantize_fp8_tensorwise_get_scale(x: Tensor) -> float: + abs_max = x.abs().max() + scale = (abs_max.float() / 448.0).clamp(min=1e-30) + return scale - quant_linear = bnb.nn.Linear8bitLt( - input_features=module.in_features, - output_features=module.out_features, - bias=bias, - has_fp16_weights=False, - ) +def quantize_fp8_axiswise_get_scale(x: Tensor, dim: int) -> Tensor: + abs_max = x.abs().amax(dim=dim, keepdim=True) + scale = (abs_max.float() / 448.0).clamp(min=1e-30) + return scale - if copy_parameters: - quant_linear.weight = type(quant_linear.weight)(module.weight) - if bias: - quant_linear.bias = type(quant_linear.bias)(module.bias) +def quantize_fp8_tensorwise(x: Tensor) -> tuple[Tensor, float]: + scale = quantize_fp8_tensorwise_get_scale(x) + q = quantize_fp8(x, scale) + return q, scale - return quant_linear +def quantize_fp8_axiswise(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + scale = quantize_fp8_axiswise_get_scale(x, dim) + q = quantize_fp8(x, scale) + return q, scale +def dequantize(q: Tensor, scale: float | Tensor, compute_dtype: torch.dtype) -> Tensor: + return q.to(compute_dtype) * scale.to(compute_dtype) -def __create_fp8_linear_layer(module: nn.Linear, copy_parameters: 
bool) -> nn.Module: - bias = module.bias is not None - quant_linear = LinearFp8( +from modules.module.quantized.LinearA8 import LinearA8 +from modules.module.quantized.LinearFp8 import LinearFp8 +from modules.module.quantized.LinearGGUFA8 import LinearGGUFA8 +from modules.module.quantized.LinearSVD import BaseLinearSVD, make_svd_linear +from modules.module.quantized.LinearW8A8 import LinearW8A8 + + +def __create_linear_layer(construct_fn, module: nn.Linear, copy_parameters: bool) -> nn.Module: + bias = module.bias is not None + quant_linear = construct_fn( in_features=module.in_features, out_features=module.out_features, bias=bias, ) if copy_parameters: - quant_linear.weight = type(quant_linear.weight)(module.weight) + quant_linear.weight = type(quant_linear.weight)(module.weight, requires_grad=False) if bias: - quant_linear.bias = type(quant_linear.bias)(module.bias) + quant_linear.bias = type(quant_linear.bias)(module.bias, requires_grad=False) return quant_linear -def __replace_linear_layers_recursive( + +def __replace_linear_layers( parent_module: nn.Module, - convert_fn: Callable[[nn.Linear, bool], nn.Module], + construct_fn, keep_in_fp32_modules: list[str] | None = None, + filters: list[ModuleFilter] | None = None, copy_parameters: bool = False, name_prefix: str = "", visited_modules: set[int] | None = None, + convert_type = nn.Linear, ): + #both 'keep_in_fp32_modules' and 'filters' are layer filters: keep_in_fp32_modules is set by diffusers, 'filters' is set by the user. + #Apply both. 'keep_in_fp32_modules' only looks at attr_name, 'filters' looks at the entire key at the leafs: if keep_in_fp32_modules is None: keep_in_fp32_modules = [] @@ -87,17 +119,22 @@ def __replace_linear_layers_recursive( visited_modules = set() visited_modules.add(id(parent_module)) + if isinstance(parent_module, (nn.ModuleList, nn.Sequential)): for i, module in enumerate(parent_module): - if isinstance(module, nn.Linear): - quant_linear = convert_fn(module, copy_parameters) + if isinstance(module, convert_type): + if filters is not None and len(filters) > 0 and not any(f.matches(name_prefix) for f in filters): + continue + + quant_linear = __create_linear_layer(construct_fn, module, copy_parameters) parent_module[i] = quant_linear del module elif id(module) not in visited_modules: - __replace_linear_layers_recursive( + __replace_linear_layers( parent_module=module, - convert_fn=convert_fn, + construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, name_prefix=f"{name_prefix}[{i}]", visited_modules=visited_modules, @@ -108,79 +145,79 @@ def __replace_linear_layers_recursive( continue module = getattr(parent_module, attr_name) - if isinstance(module, nn.Linear): - quant_linear = convert_fn(module, copy_parameters) + if isinstance(module, convert_type): + key_name = attr_name if name_prefix == "" else f"{name_prefix}.{attr_name}" + if filters is not None and len(filters) > 0 and not any(f.matches(key_name) for f in filters): + continue + + quant_linear = __create_linear_layer(construct_fn, module, copy_parameters) setattr(parent_module, attr_name, quant_linear) del module elif isinstance(module, nn.Module) and id(module) not in visited_modules: - __replace_linear_layers_recursive( + __replace_linear_layers( parent_module=module, - convert_fn=convert_fn, + construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, - name_prefix=f"{name_prefix}.{attr_name}", + name_prefix=attr_name if 
name_prefix == "" else f"{name_prefix}.{attr_name}", visited_modules=visited_modules, ) -def __replace_linear_layers( +def replace_linear_with_quantized_layers( parent_module: nn.Module, - convert_fn: Callable[[nn.Linear, bool], nn.Module], + dtype: DataType, keep_in_fp32_modules: list[str] | None = None, + filters: list[ModuleFilter] | None = None, copy_parameters: bool = False, ): - __replace_linear_layers_recursive(parent_module, convert_fn, keep_in_fp32_modules, copy_parameters) - - #ensure that all Linear layers were replaced - #https://github.com/Nerogar/OneTrainer/issues/1050 - for name, module in parent_module.named_modules(): - assert (not isinstance(module, nn.Linear) - or isinstance(module, QuantizedLinearMixin) - or any(s in name.split('.') for s in keep_in_fp32_modules) - ), f"Linear layer {name} was not found in model for quantization" - -def replace_linear_with_nf4_layers( - parent_module: nn.Module, - keep_in_fp32_modules: list[str] | None = None, - copy_parameters: bool = False, -): - __replace_linear_layers( - parent_module=parent_module, - convert_fn=__create_nf4_linear_layer, - keep_in_fp32_modules=keep_in_fp32_modules, - copy_parameters=copy_parameters, - ) - - -def replace_linear_with_int8_layers( - parent_module: nn.Module, - keep_in_fp32_modules: list[str] | None = None, - copy_parameters: bool = False, -): - __replace_linear_layers( - parent_module=parent_module, - convert_fn=__create_int8_linear_layer, - keep_in_fp32_modules=keep_in_fp32_modules, - copy_parameters=copy_parameters, - ) - + if dtype.quantize_nf4(): + construct_fn = make_svd_linear(LinearNf4) if dtype.quantize_svd() else LinearNf4 + elif dtype.quantize_int8(): + construct_fn = partial(make_svd_linear(bnb.nn.Linear8bitLt) if dtype.quantize_svd() else bnb.nn.Linear8bitLt, has_fp16_weights=False) + elif dtype.quantize_fp8(): + construct_fn = make_svd_linear(LinearFp8) if dtype.quantize_svd() else LinearFp8 + elif dtype.quantize_intW8A8(): + construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.int8, compute_dtype=torch.bfloat16) #FIXME + elif dtype.quantize_fpW8A8(): + construct_fn = partial(make_svd_linear(LinearW8A8) if dtype.quantize_svd() else LinearW8A8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) #FIXME + elif dtype == DataType.GGUF_A8_INT: + construct_fn = partial(LinearGGUFA8, dtype=torch.int8, compute_dtype=torch.bfloat16) #FIXME + elif dtype == DataType.GGUF_A8_FLOAT: + construct_fn = partial(LinearGGUFA8, dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) #FIXME + elif dtype == DataType.BFLOAT_16_A8_INT or dtype == DataType.FLOAT_16_A8_INT: + construct_fn = partial(LinearA8, dtype=torch.int8) + elif dtype == DataType.BFLOAT_16_A8_FLOAT or dtype == DataType.FLOAT_16_A8_FLOAT: + construct_fn = partial(LinearA8, dtype=torch.float8_e4m3fn) + else: + return -def replace_linear_with_fp8_layers( - parent_module: nn.Module, - keep_in_fp32_modules: list[str] | None = None, - copy_parameters: bool = False, -): + convert_type = GGUFLinear if dtype.is_gguf() else nn.Linear __replace_linear_layers( parent_module=parent_module, - convert_fn=__create_fp8_linear_layer, + construct_fn=construct_fn, keep_in_fp32_modules=keep_in_fp32_modules, + filters=filters, copy_parameters=copy_parameters, + convert_type=convert_type, ) + #ensure that all Linear layers were replaced + #https://github.com/Nerogar/OneTrainer/issues/1050 + for name, module in parent_module.named_modules(): + assert (not isinstance(module, convert_type) + or isinstance(module, 
(QuantizedLinearMixin, LinearGGUFA8, LinearA8)) + or any(s in name.split('.') for s in keep_in_fp32_modules) + or (filters is not None and len(filters) > 0 and not any(f.matches(name) for f in filters)) + ), f"Linear layer {name} was not found in model for quantization" def is_quantized_parameter( module: nn.Module, parameter_name: str, ) -> bool: + if isinstance(module, BaseLinearSVD): + if parameter_name in ["svd_up", "svd_down"]: + return True if bnb is not None: if isinstance(module, LinearNf4): return parameter_name in [ @@ -194,18 +231,21 @@ def is_quantized_parameter( elif isinstance(module, bnb.nn.Linear8bitLt): return parameter_name == "weight" - if isinstance(module, LinearFp8): + if isinstance(module, (LinearFp8, LinearW8A8)): return parameter_name == "weight" return False -def quantize_layers(module: nn.Module, device: torch.device, train_dtype: DataType): +def quantize_layers(module: nn.Module, device: torch.device, train_dtype: DataType, config: TrainConfig): if module is not None: - for child_module in module.modules(): + cache_dir = config.cache_dir + "/quantization" + os.makedirs(cache_dir, exist_ok=True) + child_modules = list(module.modules()) + for child_module in tqdm(child_modules, desc="Quantizing model weights", total=len(child_modules), delay=5, smoothing=0.1): if isinstance(child_module, QuantizedModuleMixin): child_module.compute_dtype = train_dtype.torch_dtype() - child_module.quantize(device) + child_module.quantize(device=device, cache_dir=cache_dir, svd_dtype=config.svd_dtype.torch_dtype(), rank=config.svd_rank) def get_unquantized_weight(module: nn.Linear, dtype: torch.dtype, device: torch.device) -> Tensor: @@ -232,6 +272,9 @@ def get_offload_tensors(module: nn.Module) -> list[torch.Tensor]: tensors += [module.weight] if isinstance(module, nn.Linear) and module.bias is not None: tensors += [module.bias] + if isinstance(module, BaseLinearSVD): + tensors += [module.svd_up] + tensors += [module.svd_down] return tensors diff --git a/modules/util/triton_mm_8bit.py b/modules/util/triton_mm_8bit.py new file mode 100644 index 000000000..8250e1270 --- /dev/null +++ b/modules/util/triton_mm_8bit.py @@ -0,0 +1,121 @@ +#This is a 8bit matmul kernel adapted from the Triton tutorial here: +#https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + +#It is not optimized and about 10% slower than torch._int_mm and torch._scaled_mm +#However, the torch functions don't work on row-major rhs matrices: +#_scaled_mm fails, _int_mm automatically converts to column-major +# +#Converting to column-major is slow, which is significant because the weights matrix +#of a Linear layer is always column-major during the backward pass. 
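#Editor's illustration (not part of the diff): a minimal sketch of the layout point
#above, assuming a weight stored as (n, k) like the benchmark tensors. The
#"b.T.contiguous().T" expression in the torch_backward benchmarks materializes a
#column-major copy of that weight; that copy is the conversion described as slow,
#and the kernel in this file avoids it by reading the weight through its strides.

import torch

w = torch.randn(8, 4)             # (n, k), contiguous, i.e. row-major storage
col = w.T.contiguous().T          # same shape and values, but column-major storage
print(w.is_contiguous())          # True
print(col.is_contiguous())        # False: a transposed-layout copy was made
print(torch.equal(w, col))        # True: the values themselves are unchanged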
+# +#In these cases, this Triton kernel is *much* faster because it can access the +#row-major weight matrix directly, using strided memory access + +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128}, num_stages=3,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=3,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5,num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5,num_warps=2), + + + ], + key=[ + 'QUANTIZED_M', #only tune roughly on M, because M is the transformer sequence length - can vary on data + 'N', + 'K', + 'stride_bk' #use stride of b as key, to autotune again for a strided rhs matrix (backward pass) + ], +) + +@triton.jit +def __mm_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + QUANTIZED_M, + FLOAT: tl.constexpr, +): + + pid_n = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 if FLOAT else tl.int32) + + for k in range(tl.cdiv(K, BLOCK_SIZE_K)): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k*BLOCK_SIZE_K) + b_mask = (offs_bn[None, :] < N) & (offs_k[:, None] < K - k*BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32 if FLOAT else tl.int32) + + 
a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +def mm_8bit(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert a.dtype == b.dtype, "Incompatible dtypes" + assert a.dtype in [torch.int8, torch.float8_e4m3fn] + + FLOAT = (a.dtype == torch.float8_e4m3fn) + + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float32 if FLOAT else torch.int32) + + def grid(META): + return (triton.cdiv(N, META['BLOCK_SIZE_N']) , triton.cdiv(M, META['BLOCK_SIZE_M']), ) + __mm_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + QUANTIZED_M = M // 64, + FLOAT = FLOAT + ) + return c diff --git a/modules/util/ui/UIState.py b/modules/util/ui/UIState.py index 516146823..0c858094b 100644 --- a/modules/util/ui/UIState.py +++ b/modules/util/ui/UIState.py @@ -49,6 +49,9 @@ def add_var_trace(self, name, command: Callable[[], None]) -> int: def remove_var_trace(self, name, trace_id): self.__var_traces[name].pop(trace_id) + def remove_all_var_traces(self, name): + self.__var_traces[name] = {} + def __call_var_traces(self, name): for trace in self.__var_traces[name].values(): trace() diff --git a/modules/util/ui/components.py b/modules/util/ui/components.py index 3043f32fc..961d213a0 100644 --- a/modules/util/ui/components.py +++ b/modules/util/ui/components.py @@ -296,6 +296,113 @@ def time_entry(master, row, column, ui_state: UIState, var_name: str, unit_var_n return frame +def layer_filter_entry(master, row, column, ui_state: UIState, preset_var_name: str, preset_label: str, preset_tooltip: str, presets, entry_var_name, entry_tooltip: str, regex_var_name, regex_tooltip: str, frame_color=None): + frame = ctk.CTkFrame(master=master, corner_radius=5, fg_color=frame_color) + frame.grid(row=row, column=column, padx=5, pady=5, sticky="nsew") + frame.grid_columnconfigure(0, weight=1) + + layer_entry = entry( + frame, 1, 0, ui_state, entry_var_name, + tooltip=entry_tooltip + ) + layer_entry_fg_color = layer_entry.cget("fg_color") + layer_entry_text_color = layer_entry.cget("text_color") + + regex_label = label( + frame, 2, 0, "Use Regex", + tooltip=regex_tooltip, + ) + regex_switch = switch( + frame, 2, 1, ui_state, regex_var_name + ) + + # Let the user set their own layer filter + # TODO + #if self.train_config.layer_filter and self.train_config.layer_filter_preset == "custom": + # self.prior_custom = self.train_config.layer_filter + #else: + # self.prior_custom = "" + + layer_entry.grid_configure(columnspan=2, sticky="ew") + + presets_list = list(presets.keys()) + ["custom"] + + + def hide_layer_entry(): + if layer_entry and layer_entry.winfo_manager(): + layer_entry.grid_remove() + + def show_layer_entry(): + if layer_entry and not layer_entry.winfo_manager(): + layer_entry.grid() + + + def preset_set_layer_choice(selected: str): + if not selected or selected not in presets_list: + selected = presets_list[0] + + if selected == "custom": + # Allow editing + regex toggle + show_layer_entry() + layer_entry.configure(state="normal", fg_color=layer_entry_fg_color, 
diff --git a/modules/util/ui/UIState.py b/modules/util/ui/UIState.py
index 516146823..0c858094b 100644
--- a/modules/util/ui/UIState.py
+++ b/modules/util/ui/UIState.py
@@ -49,6 +49,9 @@ def add_var_trace(self, name, command: Callable[[], None]) -> int:
     def remove_var_trace(self, name, trace_id):
         self.__var_traces[name].pop(trace_id)

+    def remove_all_var_traces(self, name):
+        self.__var_traces[name] = {}
+
     def __call_var_traces(self, name):
         for trace in self.__var_traces[name].values():
             trace()
diff --git a/modules/util/ui/components.py b/modules/util/ui/components.py
index 3043f32fc..961d213a0 100644
--- a/modules/util/ui/components.py
+++ b/modules/util/ui/components.py
@@ -296,6 +296,113 @@ def time_entry(master, row, column, ui_state: UIState, var_name: str, unit_var_n
     return frame


+def layer_filter_entry(master, row, column, ui_state: UIState, preset_var_name: str, preset_label: str, preset_tooltip: str, presets, entry_var_name, entry_tooltip: str, regex_var_name, regex_tooltip: str, frame_color=None):
+    frame = ctk.CTkFrame(master=master, corner_radius=5, fg_color=frame_color)
+    frame.grid(row=row, column=column, padx=5, pady=5, sticky="nsew")
+    frame.grid_columnconfigure(0, weight=1)
+
+    layer_entry = entry(
+        frame, 1, 0, ui_state, entry_var_name,
+        tooltip=entry_tooltip
+    )
+    layer_entry_fg_color = layer_entry.cget("fg_color")
+    layer_entry_text_color = layer_entry.cget("text_color")
+
+    regex_label = label(
+        frame, 2, 0, "Use Regex",
+        tooltip=regex_tooltip,
+    )
+    regex_switch = switch(
+        frame, 2, 1, ui_state, regex_var_name
+    )
+
+    # Let the user set their own layer filter
+    # TODO
+    #if self.train_config.layer_filter and self.train_config.layer_filter_preset == "custom":
+    #    self.prior_custom = self.train_config.layer_filter
+    #else:
+    #    self.prior_custom = ""
+
+    layer_entry.grid_configure(columnspan=2, sticky="ew")
+
+    presets_list = list(presets.keys()) + ["custom"]
+
+    def hide_layer_entry():
+        if layer_entry and layer_entry.winfo_manager():
+            layer_entry.grid_remove()
+
+    def show_layer_entry():
+        if layer_entry and not layer_entry.winfo_manager():
+            layer_entry.grid()
+
+    def preset_set_layer_choice(selected: str):
+        if not selected or selected not in presets_list:
+            selected = presets_list[0]
+
+        if selected == "custom":
+            # Allow editing + regex toggle
+            show_layer_entry()
+            layer_entry.configure(state="normal", fg_color=layer_entry_fg_color, text_color=layer_entry_text_color)
+            #layer_entry.cget('textvariable').set("")
+            regex_label.grid()
+            regex_switch.grid()
+        else:
+            # Preserve custom text before overwriting
+            #if self.prior_selected == "custom":
+            #    self.prior_custom = self.layer_entry.get()
+
+            # Resolve preset definition (list[str] OR {'patterns': [...], 'regex': bool})
+            preset_def = presets.get(selected, [])
+            if isinstance(preset_def, dict):
+                patterns = preset_def.get("patterns", [])
+                preset_uses_regex = bool(preset_def.get("regex", False))
+            else:
+                patterns = preset_def
+                preset_uses_regex = False
+
+            disabled_color = ("gray85", "gray17")
+            disabled_text_color = ("gray30", "gray70")
+            layer_entry.configure(state="disabled", fg_color=disabled_color, text_color=disabled_text_color)
+            layer_entry.cget('textvariable').set(",".join(patterns))
+
+            ui_state.get_var(entry_var_name).set(",".join(patterns))
+            ui_state.get_var(regex_var_name).set(preset_uses_regex)
+
+            regex_label.grid_remove()
+            regex_switch.grid_remove()
+
+            if selected == "full" and not patterns:
+                hide_layer_entry()
+            else:
+                show_layer_entry()
+
+        # self.prior_selected = selected
+
+    label(frame, 0, 0, preset_label,
+          tooltip=preset_tooltip)
+
+    ui_state.remove_all_var_traces(preset_var_name)
+
+    layer_selector = options(
+        frame, 0, 1, presets_list, ui_state, preset_var_name,
+        command=preset_set_layer_choice
+    )
+
+    def on_layer_filter_preset_change():
+        if not layer_selector:
+            return
+        selected = ui_state.get_var(preset_var_name).get()
+        preset_set_layer_choice(selected)
+
+    ui_state.add_var_trace(
+        preset_var_name,
+        on_layer_filter_preset_change,
+    )
+
+    preset_set_layer_choice(layer_selector.get())

 def icon_button(master, row, column, text, command):
     component = ctk.CTkButton(master, text=text, width=40, command=command)
diff --git a/training_presets/#chroma LoRA 16GB.json b/training_presets/#chroma LoRA 16GB.json
index a99eecca4..0cbe71627 100644
--- a/training_presets/#chroma LoRA 16GB.json
+++ b/training_presets/#chroma LoRA 16GB.json
@@ -22,5 +22,7 @@
     "timestep_distribution": "INVERTED_PARABOLA",
     "noising_weight": 7.7,
     "layer_filter": "attn,ff.net",
-    "layer_filter_preset": "attn-mlp"
+    "layer_filter_preset": "attn-mlp",
+    "quantization_layer_filter": "transformer_block",
+    "quantization_layer_filter_preset": "blocks"
 }
diff --git a/training_presets/#chroma LoRA 24GB.json b/training_presets/#chroma LoRA 24GB.json
index 5877009c6..76868401b 100644
--- a/training_presets/#chroma LoRA 24GB.json
+++ b/training_presets/#chroma LoRA 24GB.json
@@ -22,5 +22,7 @@
     "timestep_distribution": "INVERTED_PARABOLA",
     "noising_weight": 7.7,
     "layer_filter": "attn,ff.net",
-    "layer_filter_preset": "attn-mlp"
+    "layer_filter_preset": "attn-mlp",
+    "quantization_layer_filter": "transformer_block",
+    "quantization_layer_filter_preset": "blocks"
 }
diff --git a/training_presets/#chroma LoRA 8GB.json b/training_presets/#chroma LoRA 8GB.json
index 6fa19670c..e8a5e556b 100644
--- a/training_presets/#chroma LoRA 8GB.json
+++ b/training_presets/#chroma LoRA 8GB.json
@@ -25,5 +25,7 @@
     "timestep_distribution": "INVERTED_PARABOLA",
     "noising_weight": 7.7,
     "layer_filter": "attn,ff.net",
-    "layer_filter_preset": "attn-mlp"
+    "layer_filter_preset": "attn-mlp",
+    "quantization_layer_filter": "transformer_block",
+    "quantization_layer_filter_preset": "blocks"
 }
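For reference, preset_set_layer_choice above accepts two preset formats. A hypothetical presets mapping could look like the sketch below; the names and patterns are illustrative (the "attn-mlp" and "blocks" entries mirror the layer_filter and quantization_layer_filter strings in the preset files), and the project's real preset tables are defined elsewhere in the codebase.

# plain-list presets fill the entry with ",".join(patterns) and turn the regex switch off;
# dict presets can additionally force the regex switch on
example_presets = {
    "attn-mlp": ["attn", "ff.net"],
    "blocks": ["transformer_block"],
    "regex-single-blocks": {
        "patterns": [r"single_transformer_blocks\.\d+"],  # illustrative pattern only
        "regex": True,
    },
    "full": [],   # an empty "full" preset hides the entry widget entirely
}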
diff --git a/training_presets/#qwen LoRA 16GB.json b/training_presets/#qwen LoRA 16GB.json
index a101c788c..7d877b9fc 100644
--- a/training_presets/#qwen LoRA 16GB.json
+++ b/training_presets/#qwen LoRA 16GB.json
@@ -24,5 +24,7 @@
     "output_dtype": "BFLOAT_16",
     "timestep_distribution": "LOGIT_NORMAL",
     "layer_filter": "attn,img_mlp,txt_mlp",
-    "layer_filter_preset": "attn-mlp"
+    "layer_filter_preset": "attn-mlp",
+    "quantization_layer_filter": "transformer_block",
+    "quantization_layer_filter_preset": "blocks"
 }
diff --git a/training_presets/#qwen LoRA 24GB.json b/training_presets/#qwen LoRA 24GB.json
index bd03b0cc2..66bd7324d 100644
--- a/training_presets/#qwen LoRA 24GB.json
+++ b/training_presets/#qwen LoRA 24GB.json
@@ -24,5 +24,7 @@
     "output_dtype": "BFLOAT_16",
     "timestep_distribution": "LOGIT_NORMAL",
     "layer_filter": "attn,img_mlp,txt_mlp",
-    "layer_filter_preset": "attn-mlp"
+    "layer_filter_preset": "attn-mlp",
+    "quantization_layer_filter": "transformer_block",
+    "quantization_layer_filter_preset": "blocks"
 }
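A sketch of how the new widget and preset keys might be wired together in a UI tab. Only the two config names, quantization_layer_filter and quantization_layer_filter_preset, come from the preset files above; the row/column, tooltips, presets mapping, and the regex variable name are assumptions.

from modules.util.ui import components

def add_quantization_filter_widget(master, ui_state):
    # hypothetical call site; master and ui_state are whatever the surrounding tab provides
    components.layer_filter_entry(
        master, 5, 0, ui_state,
        preset_var_name="quantization_layer_filter_preset",
        preset_label="Quantization Layers",
        preset_tooltip="Preset controlling which modules get quantized.",
        presets={"blocks": ["transformer_block"], "full": []},  # illustrative, see the mapping sketched earlier
        entry_var_name="quantization_layer_filter",
        entry_tooltip="Comma-separated module name patterns to quantize.",
        regex_var_name="quantization_layer_filter_regex",  # assumed config key
        regex_tooltip="Interpret the patterns as regular expressions.",
    )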