
Commit aa1240a

fix: transformers deprecate load_in_Xbit in model_kwargs (#3205)

* fix: transformers deprecate load_in_Xbit in model_kwargs
* fix: test to read from quantization_config kwarg
* fix: test
* fix: access
* fix: test weirdly entering incorrect config

1 parent 4cdfdfe commit aa1240a
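For context, transformers deprecated passing `load_in_8bit` / `load_in_4bit` directly as `from_pretrained` kwargs in favor of a `quantization_config` object, which is what this commit adapts to. A minimal sketch of the two patterns; the model name is illustrative, not from this commit:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Deprecated: quantization flags passed directly as model kwargs.
# model = AutoModelForCausalLM.from_pretrained(
#     "facebook/opt-125m", load_in_8bit=True
# )

# Current: wrap the flags in a BitsAndBytesConfig and pass it as
# quantization_config (requires bitsandbytes and a CUDA device to run).
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
```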

File tree

2 files changed: +22 −24 lines changed

src/axolotl/loaders/model.py

Lines changed: 2 additions & 14 deletions

@@ -515,9 +515,6 @@ def _set_quantization_config(self):
         if self.cfg.model_quantization_config_kwargs:
             mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
             self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
-        else:
-            self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
-            self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
 
         if self.cfg.gptq:
             if not hasattr(self.model_config, "quantization_config"):
@@ -552,9 +549,7 @@ def _set_quantization_config(self):
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **self.model_config.quantization_config
             )
-        elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
-            "load_in_4bit", False
-        ):
+        elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
             bnb_config = {
                 "load_in_4bit": True,
                 "llm_int8_threshold": 6.0,
@@ -580,9 +575,7 @@ def _set_quantization_config(self):
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
             )
-        elif self.cfg.adapter == "lora" and self.model_kwargs.get(
-            "load_in_8bit", False
-        ):
+        elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
             bnb_config = {
                 "load_in_8bit": True,
             }
@@ -596,11 +589,6 @@ def _set_quantization_config(self):
                 **bnb_config,
             )
 
-        # no longer needed per https://github.com/huggingface/transformers/pull/26610
-        if "quantization_config" in self.model_kwargs or self.cfg.gptq:
-            self.model_kwargs.pop("load_in_8bit", None)
-            self.model_kwargs.pop("load_in_4bit", None)
-
     def _set_attention_config(self):
         """Sample packing uses custom FA2 patch"""
         if self.cfg.attn_implementation:
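The QLoRA and LoRA branches now key off `self.cfg.load_in_4bit` / `self.cfg.load_in_8bit` directly instead of reading back entries the loader itself had written into `model_kwargs`. A hedged sketch of what the QLoRA branch ends up producing; only `load_in_4bit` and `llm_int8_threshold` appear in the hunk above, and the `bnb_4bit_*` keys are commonly used QLoRA settings assumed here for completeness:

```python
import torch
from transformers import BitsAndBytesConfig

bnb_config = {
    "load_in_4bit": True,
    "llm_int8_threshold": 6.0,
    # Assumed, typical QLoRA options (not shown in the hunk above):
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
}

# The loader stores the result only under quantization_config; no bare
# load_in_Xbit flags remain in the model kwargs.
model_kwargs = {"quantization_config": BitsAndBytesConfig(**bnb_config)}
```

Because the flags now live only on `cfg` and the result only under `quantization_config`, there is no longer a second copy in `model_kwargs` to clean up, which is why the trailing pop block referencing transformers PR #26610 could be deleted.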

tests/test_loaders.py

Lines changed: 20 additions & 10 deletions

@@ -80,16 +80,26 @@ def test_set_quantization_config(
                 hasattr(self.model_loader.model_kwargs, "load_in_8bit")
                 and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
             )
-        elif load_in_8bit and self.cfg.adapter is not None:
-            assert self.model_loader.model_kwargs["load_in_8bit"]
-        elif load_in_4bit and self.cfg.adapter is not None:
-            assert self.model_loader.model_kwargs["load_in_4bit"]
-
-        if (self.cfg.adapter == "qlora" and load_in_4bit) or (
-            self.cfg.adapter == "lora" and load_in_8bit
-        ):
-            assert self.model_loader.model_kwargs.get(
-                "quantization_config", BitsAndBytesConfig
+
+        if self.cfg.adapter == "qlora" and load_in_4bit:
+            assert isinstance(
+                self.model_loader.model_kwargs.get("quantization_config"),
+                BitsAndBytesConfig,
+            )
+
+            assert (
+                self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
+                is True
+            )
+        if self.cfg.adapter == "lora" and load_in_8bit:
+            assert isinstance(
+                self.model_loader.model_kwargs.get("quantization_config"),
+                BitsAndBytesConfig,
+            )
+
+            assert (
+                self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
+                is True
             )
 
     def test_message_property_mapping(self):
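The rewritten assertions reach for the private `_load_in_4bit` / `_load_in_8bit` fields. In recent transformers releases, `BitsAndBytesConfig` exposes `load_in_4bit` / `load_in_8bit` as properties backed by those attributes (the setters guard against enabling both modes at once), so outside a test the public property is the usual way to read the flag. A minimal sketch, assuming a transformers version with that property layout:

```python
from transformers import BitsAndBytesConfig

qconf = BitsAndBytesConfig(load_in_4bit=True)
assert qconf.load_in_4bit is True    # public property
assert qconf._load_in_4bit is True   # private backing field the test asserts on
```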
