
Commit fea4d11

baberabb and shanhx2000 authored
[HF] fix quantization config (#3039)
* Try fixing issue 3026, which is caused by the quantization_config argument introduced in commit 758c5ed. The argument is a dict, but for a GPTQ-quantized model this conflicts with the Hugging Face interface, which expects a QuantizationConfigMixin. The initial workaround was to remove the quantization_config argument in HFLM._create_model() of lm_eval/models/huggingface.py; the follow-up changes below restore the functionality provided by the previous commit.

* Wrap quantization_config in AutoQuantizationConfig

* Handle a quantization config that is not a dict

* Wrap quantization_config in AutoQuantizationConfig only if it is a dict

---------

Co-authored-by: shanhx2000 <[email protected]>
1 parent 6b3f3f7 commit fea4d11
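
The core of the fix is easiest to see in isolation: when a checkpoint's config carries its quantization settings as a plain dict (as GPTQ-quantized models do), the dict is converted into the config object the Hugging Face loader expects before being passed along. A minimal sketch of that pattern, with the model id as a hypothetical placeholder:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("my-org/my-gptq-model")  # hypothetical id

    # The attribute may be absent, a plain dict, or an already-built config object.
    quantization_config = getattr(config, "quantization_config", None)
    if quantization_config is not None and isinstance(quantization_config, dict):
        # Deferred import: transformers.quantizers is only needed on this path.
        from transformers.quantizers import AutoQuantizationConfig

        # Build the QuantizationConfigMixin subclass (e.g. GPTQConfig) that
        # from_pretrained(...) expects, instead of handing it the raw dict.
        quantization_config = AutoQuantizationConfig.from_dict(quantization_config)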

File tree

1 file changed: +16 -5


lm_eval/models/huggingface.py

Lines changed: 16 additions & 5 deletions
@@ -3,7 +3,7 @@
 import os
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
 import jinja2
 import torch
@@ -17,8 +17,6 @@
 from accelerate.utils import get_max_memory
 from huggingface_hub import HfApi
 from packaging import version
-from peft import PeftModel
-from peft import __version__ as PEFT_VERSION
 from tqdm import tqdm
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
@@ -40,6 +38,9 @@
 )
 
 
+if TYPE_CHECKING:
+    from transformers.quantizers import AutoQuantizationConfig
+
 eval_logger = logging.getLogger(__name__)
 
 
@@ -188,6 +189,13 @@ def __init__(
             add_bos_token=add_bos_token,
         )
 
+        if (
+            quantization_config := getattr(self.config, "quantization_config", None)
+        ) is not None and isinstance(quantization_config, dict):
+            from transformers.quantizers import AutoQuantizationConfig
+
+            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
+
         # if we passed `pretrained` as a string, initialize our model now
         if isinstance(pretrained, str):
             self._create_model(
@@ -205,7 +213,7 @@ def __init__(
                 autogptq=autogptq,
                 gptqmodel=gptqmodel,
                 gguf_file=gguf_file,
-                quantization_config=getattr(self.config, "quantization_config", None),
+                quantization_config=quantization_config,
                 subfolder=subfolder,
                 **kwargs,
             )
@@ -554,7 +562,7 @@ def _create_model(
         autogptq: Optional[Union[bool, str]] = False,
         gptqmodel: Optional[bool] = False,
         gguf_file: Optional[str] = None,
-        quantization_config: Optional[Dict[str, Any]] = None,
+        quantization_config: Optional["AutoQuantizationConfig"] = None,
         subfolder: str = "",
         **kwargs,
     ) -> None:
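
A note on the typing change above: TYPE_CHECKING is always False at runtime, so the guarded import is seen only by static type checkers, and the quoted annotation Optional["AutoQuantizationConfig"] is a forward reference that Python never has to resolve when the module is imported. A self-contained sketch of the pattern:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Evaluated by type checkers (mypy, pyright); skipped at runtime.
        from transformers.quantizers import AutoQuantizationConfig


    def _create_model(
        quantization_config: Optional["AutoQuantizationConfig"] = None,
    ) -> None:
        # The quoted annotation is a forward reference, so defining this
        # function never triggers a runtime import of transformers.quantizers.
        ...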
@@ -649,6 +657,9 @@ def _create_model(
             )
 
         if peft:
+            from peft import PeftModel
+            from peft import __version__ as PEFT_VERSION
+
             if model_kwargs.get("load_in_4bit", None):
                 if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
                     raise AssertionError("load_in_4bit requires peft >= 0.4.0")
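
Moving the peft imports into the if peft: branch applies the same deferral to an optional dependency: peft is imported only when an adapter is actually requested, so the module still loads in environments where peft is not installed. A rough sketch of the idea, with a hypothetical helper name:

    # maybe_load_adapter is a hypothetical illustration, not part of the patch.
    def maybe_load_adapter(model, peft_path=None):
        if peft_path:
            from peft import PeftModel  # optional dependency, imported on demand

            model = PeftModel.from_pretrained(model, peft_path)
        return model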
