87 commits
1e8fece
Compile int svd
dxqb Oct 4, 2025
bbf6eb8
- fix cache dir
dxqb Oct 5, 2025
2676d33
hide checkpoints from LoRA saving
dxqb Oct 5, 2025
dc289fa
fix buffer registration
dxqb Oct 6, 2025
73822b0
fix buffer registration
dxqb Oct 6, 2025
4821c9e
various
dxqb Oct 13, 2025
cb02a4c
various
dxqb Oct 13, 2025
efb9073
various
dxqb Oct 13, 2025
35ba023
cleanup
dxqb Oct 13, 2025
5633b21
torch.compile bug workaround
dxqb Oct 14, 2025
c37f805
same workaround for Qwen
dxqb Oct 14, 2025
d7532dc
gguf
dxqb Oct 15, 2025
6883981
merge
dxqb Oct 15, 2025
2de19c9
gguf
dxqb Oct 15, 2025
a3cd936
bugfix
dxqb Oct 15, 2025
bcf0b65
requirements
dxqb Oct 15, 2025
0dcb3dc
merge
dxqb Oct 15, 2025
5a2e590
name changes, axis wise
dxqb Oct 16, 2025
44a969b
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 16, 2025
882154d
merge
dxqb Oct 16, 2025
e5317d3
big type hint
dxqb Oct 16, 2025
7f07748
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 16, 2025
c8cb33b
use axis-wise quantization for both forward and backward
dxqb Oct 16, 2025
c6011d0
Merge branch 'upstream' into gguf
dxqb Oct 16, 2025
b772d42
merge
dxqb Oct 16, 2025
3167b90
merge
dxqb Oct 16, 2025
2d4a0c3
initial
dxqb Oct 16, 2025
68f0f71
merge #1060
dxqb Oct 16, 2025
0e282a2
merge
dxqb Oct 16, 2025
71af1f0
ui fix
dxqb Oct 16, 2025
1606c85
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 16, 2025
0f58a5e
GGUF with DoRA
dxqb Oct 16, 2025
8ca5782
GGUF with DoRA
dxqb Oct 16, 2025
881e7e5
GGUF A8 float bugfix
dxqb Oct 17, 2025
25ccc0c
improve check for #1050
dxqb Oct 17, 2025
cfc1492
Merge branch 't5' into compile_int8_svd
dxqb Oct 17, 2025
b4d8f30
improve check for #1050
dxqb Oct 17, 2025
827fa11
improve check for #1050
dxqb Oct 17, 2025
da296d6
re-enabled int W8A8
dxqb Oct 17, 2025
f273ac3
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 17, 2025
a3de776
merge
dxqb Oct 23, 2025
7351e9a
Merge branch 'upstream' into compile_int8_svd
dxqb Oct 26, 2025
56902b6
only quantize activations if GGUF weights are actually quantized
dxqb Oct 28, 2025
eaf4fe2
make layer filter a component
dxqb Nov 2, 2025
867b84c
quantization layer filter
dxqb Nov 2, 2025
e5d0317
add blocks preset
dxqb Nov 2, 2025
1f45c04
Merge branch 'blocks' into quant_layer_filter
dxqb Nov 2, 2025
30b7cca
merge
dxqb Nov 2, 2025
0242d16
quantization filter in presets
dxqb Nov 2, 2025
e5f5b0b
Merge branch 'quant_layer_filter' into compile_int8_svd
dxqb Nov 2, 2025
9e897b8
#1054
dxqb Nov 2, 2025
9aa7973
Merge branch 'config-prefix' into compile_int8_svd
dxqb Nov 2, 2025
3448801
bugfix
dxqb Nov 2, 2025
40d61c5
Merge branch 'quant_layer_filter' into compile_int8_svd
dxqb Nov 2, 2025
cf70b7a
Merge branch 'upstream' into compile_int8_svd
dxqb Nov 2, 2025
5bc6c5a
smaller eps, because gradients for some models are close to 1e-12
dxqb Nov 3, 2025
fb1e8a8
compile benchmarks
dxqb Nov 4, 2025
4ca84db
remove cast
dxqb Nov 4, 2025
41e44a2
detach dequantized weights
dxqb Nov 4, 2025
b3f69ae
name changes
dxqb Nov 7, 2025
f9c12a8
move code
dxqb Nov 7, 2025
5db4161
fix circular dependency
dxqb Nov 7, 2025
cd3f971
ensure contiguous grad output
dxqb Nov 7, 2025
41a05b9
W16A8
dxqb Nov 7, 2025
228c976
fix comment
dxqb Nov 7, 2025
a030b23
DataType bugfix
dxqb Nov 7, 2025
9159d24
avoid attention mask
dxqb Nov 8, 2025
7760af3
Merge branch 'avoid_attn_mask' into compile_int8_svd
dxqb Nov 8, 2025
9559ebf
disable bug workaround - can currently not be reproduced and because …
dxqb Nov 8, 2025
49d2bc4
pad sequence length if an attention mask is necessary anyway
dxqb Nov 9, 2025
2ecf834
merge
dxqb Nov 9, 2025
dc71ef5
Merge branch 'upstream' into compile_int8_svd
dxqb Nov 11, 2025
c756b2c
merge
dxqb Nov 14, 2025
d6bb1ff
Merge branch 'upstream' into compile_int8_svd
dxqb Nov 14, 2025
27dc59d
merge
dxqb Nov 14, 2025
69f0fa1
merge fix
dxqb Nov 14, 2025
1739bfe
Fixes [Bug]: Layer filter isn't configured correct if a preset is loaded
O-J1 Nov 16, 2025
968b2a9
Simplify tooltip text for layer filter
dxqb Nov 16, 2025
6f2d2f5
Tweak tooltip text a little more
O-J1 Nov 16, 2025
ea29f8f
Merge branch 'upstream' into quant_layer_filter
dxqb Nov 16, 2025
606b7a8
merge with #1139
dxqb Nov 22, 2025
8ce4604
merge with upstream
dxqb Nov 22, 2025
c9a7d07
fix to Dtypes, to avoid leaving weights at float32
dxqb Nov 22, 2025
50982b7
UI update
dxqb Nov 22, 2025
16a6015
UI update
dxqb Nov 22, 2025
fde79ce
UI change and merge
dxqb Nov 22, 2025
bf505b4
Merge branch 'compile_int8_svd' into W16A8
dxqb Nov 22, 2025
4 changes: 3 additions & 1 deletion modules/modelLoader/ChromaEmbeddingModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class ChromaEmbeddingModelLoader(
@@ -32,6 +33,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> ChromaModel | None:
         base_model_loader = ChromaModelLoader()
         embedding_loader = ChromaEmbeddingLoader()
@@ -41,7 +43,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             embedding_loader.load(model, model_names.embedding.model_name, model_names)
 
         return model
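
The same change repeats mechanically across every loader below: load() gains an optional quant_filters parameter and passes it through to the base model loader, so quantization can be restricted to selected modules. The diff does not show the ModuleFilter interface itself (modules/util/ModuleFilter.py); as a rough illustration, here is a minimal sketch of what such a filter could look like. The glob-pattern matching and the quantize-everything default are assumptions, not taken from this PR:

import fnmatch

class ModuleFilter:
    # Assumed interface: selects modules by a glob pattern on their name.
    def __init__(self, pattern: str):
        self.pattern = pattern  # e.g. "transformer_blocks.*"

    def matches(self, module_name: str) -> bool:
        return fnmatch.fnmatch(module_name, self.pattern)

def should_quantize(module_name: str, quant_filters: list[ModuleFilter] | None) -> bool:
    # Assumption: no filters configured means every eligible module is quantized.
    if not quant_filters:
        return True
    return any(f.matches(module_name) for f in quant_filters)
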
4 changes: 3 additions & 1 deletion modules/modelLoader/ChromaFineTuneModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class ChromaFineTuneModelLoader(
@@ -32,6 +33,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> ChromaModel | None:
         base_model_loader = ChromaModelLoader()
         embedding_loader = ChromaEmbeddingLoader()
@@ -41,7 +43,7 @@
         self._load_internal_data(model, model_names.base_model)
         model.model_spec = self._load_default_model_spec(model_type)
 
-        base_model_loader.load(model, model_type, model_names, weight_dtypes)
+        base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
         embedding_loader.load(model, model_names.base_model, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/ChromaLoRAModelLoader.py
@@ -8,6 +8,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class ChromaLoRAModelLoader(
@@ -33,6 +34,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> ChromaModel | None:
         base_model_loader = ChromaModelLoader()
         lora_model_loader = ChromaLoRALoader()
@@ -43,7 +45,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             lora_model_loader.load(model, model_names)
             embedding_loader.load(model, model_names.lora, model_names)
4 changes: 3 additions & 1 deletion modules/modelLoader/FluxEmbeddingModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class FluxEmbeddingModelLoader(
@@ -34,6 +35,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> FluxModel | None:
         base_model_loader = FluxModelLoader()
         embedding_loader = FluxEmbeddingLoader()
@@ -43,7 +45,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             embedding_loader.load(model, model_names.embedding.model_name, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/FluxFineTuneModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class FluxFineTuneModelLoader(
@@ -34,6 +35,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> FluxModel | None:
         base_model_loader = FluxModelLoader()
         embedding_loader = FluxEmbeddingLoader()
@@ -43,7 +45,7 @@
         self._load_internal_data(model, model_names.base_model)
         model.model_spec = self._load_default_model_spec(model_type)
 
-        base_model_loader.load(model, model_type, model_names, weight_dtypes)
+        base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
         embedding_loader.load(model, model_names.base_model, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/FluxLoRAModelLoader.py
@@ -8,6 +8,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class FluxLoRAModelLoader(
@@ -35,6 +36,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> FluxModel | None:
         base_model_loader = FluxModelLoader()
         lora_model_loader = FluxLoRALoader()
@@ -45,7 +47,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             lora_model_loader.load(model, model_names)
             embedding_loader.load(model, model_names.lora, model_names)
4 changes: 3 additions & 1 deletion modules/modelLoader/HiDreamEmbeddingModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class HiDreamEmbeddingModelLoader(
@@ -32,6 +33,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> HiDreamModel | None:
         base_model_loader = HiDreamModelLoader()
         embedding_loader = HiDreamEmbeddingLoader()
@@ -41,7 +43,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             embedding_loader.load(model, model_names.embedding.model_name, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/HiDreamFineTuneModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class HiDreamFineTuneModelLoader(
@@ -32,6 +33,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> HiDreamModel | None:
         base_model_loader = HiDreamModelLoader()
         embedding_loader = HiDreamEmbeddingLoader()
@@ -41,7 +43,7 @@
         self._load_internal_data(model, model_names.base_model)
         model.model_spec = self._load_default_model_spec(model_type)
 
-        base_model_loader.load(model, model_type, model_names, weight_dtypes)
+        base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
         embedding_loader.load(model, model_names.base_model, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/HiDreamLoRAModelLoader.py
@@ -8,6 +8,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class HiDreamLoRAModelLoader(
@@ -33,6 +34,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> HiDreamModel | None:
         base_model_loader = HiDreamModelLoader()
         lora_model_loader = HiDreamLoRALoader()
@@ -43,7 +45,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             lora_model_loader.load(model, model_names)
             embedding_loader.load(model, model_names.lora, model_names)
4 changes: 3 additions & 1 deletion modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class HunyuanVideoEmbeddingModelLoader(
@@ -32,6 +33,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> HunyuanVideoModel | None:
         base_model_loader = HunyuanVideoModelLoader()
         embedding_loader = HunyuanVideoEmbeddingLoader()
@@ -41,7 +43,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             embedding_loader.load(model, model_names.embedding.model_name, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/HunyuanVideoFineTuneModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class HunyuanVideoFineTuneModelLoader(
@@ -32,6 +33,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> HunyuanVideoModel | None:
         base_model_loader = HunyuanVideoModelLoader()
         embedding_loader = HunyuanVideoEmbeddingLoader()
@@ -41,7 +43,7 @@
         self._load_internal_data(model, model_names.base_model)
         model.model_spec = self._load_default_model_spec(model_type)
 
-        base_model_loader.load(model, model_type, model_names, weight_dtypes)
+        base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
         embedding_loader.load(model, model_names.base_model, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/HunyuanVideoLoRAModelLoader.py
@@ -8,6 +8,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class HunyuanVideoLoRAModelLoader(
@@ -33,6 +34,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> HunyuanVideoModel | None:
         base_model_loader = HunyuanVideoModelLoader()
         lora_model_loader = HunyuanVideoLoRALoader()
@@ -43,7 +45,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             lora_model_loader.load(model, model_names)
             embedding_loader.load(model, model_names.lora, model_names)
4 changes: 3 additions & 1 deletion modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class PixArtAlphaEmbeddingModelLoader(
@@ -34,6 +35,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> PixArtAlphaModel | None:
         base_model_loader = PixArtAlphaModelLoader()
         embedding_loader = PixArtAlphaEmbeddingLoader()
@@ -43,7 +45,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             embedding_loader.load(model, model_names.embedding.model_name, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/PixArtAlphaFineTuneModelLoader.py
@@ -7,6 +7,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class PixArtAlphaFineTuneModelLoader(
@@ -34,6 +35,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> PixArtAlphaModel | None:
         base_model_loader = PixArtAlphaModelLoader()
         embedding_loader = PixArtAlphaEmbeddingLoader()
@@ -43,7 +45,7 @@
         self._load_internal_data(model, model_names.base_model)
         model.model_spec = self._load_default_model_spec(model_type)
 
-        base_model_loader.load(model, model_type, model_names, weight_dtypes)
+        base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
         embedding_loader.load(model, model_names.base_model, model_names)
 
         return model
4 changes: 3 additions & 1 deletion modules/modelLoader/PixArtAlphaLoRAModelLoader.py
@@ -8,6 +8,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class PixArtAlphaLoRAModelLoader(
@@ -35,6 +36,7 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> PixArtAlphaModel | None:
         base_model_loader = PixArtAlphaModelLoader()
         lora_model_loader = PixArtAlphaLoRALoader()
@@ -45,7 +47,7 @@
         model.model_spec = self._load_default_model_spec(model_type)
 
         if model_names.base_model is not None:
-            base_model_loader.load(model, model_type, model_names, weight_dtypes)
+            base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
             lora_model_loader.load(model, model_names)
             embedding_loader.load(model, model_names.lora, model_names)
4 changes: 3 additions & 1 deletion modules/modelLoader/QwenFineTuneModelLoader.py
@@ -6,6 +6,7 @@
 from modules.util.enum.ModelType import ModelType
 from modules.util.ModelNames import ModelNames
 from modules.util.ModelWeightDtypes import ModelWeightDtypes
+from modules.util.ModuleFilter import ModuleFilter
 
 
 class QwenFineTuneModelLoader(
@@ -31,13 +32,14 @@ def load(
             model_type: ModelType,
             model_names: ModelNames,
             weight_dtypes: ModelWeightDtypes,
+            quant_filters: list[ModuleFilter] | None = None,
     ) -> QwenModel | None:
         base_model_loader = QwenModelLoader()
         model = QwenModel(model_type=model_type)
 
         self._load_internal_data(model, model_names.base_model)
         model.model_spec = self._load_default_model_spec(model_type)
 
-        base_model_loader.load(model, model_type, model_names, weight_dtypes)
+        base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
 
         return model
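
With quant_filters threaded through every loader, a caller can scope quantization to part of the model. A hypothetical invocation, where the ModuleFilter constructor arguments are assumed (matching the sketch above) and the surrounding variables come from the caller's context:

loader = QwenFineTuneModelLoader()
model = loader.load(
    model_type,      # ModelType value for the Qwen model
    model_names,     # ModelNames with base_model set
    weight_dtypes,   # ModelWeightDtypes selecting the quantized weight dtypes
    quant_filters=[ModuleFilter("transformer_blocks.*")],  # hypothetical pattern filter
)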