
Commit c520d81

Fix (brevitas_examples/llm): fix transformers tests (#1446)
1 parent e95d5bc commit c520d81

5 files changed (+8, -23 lines changed)


requirements/requirements-llm.txt

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@ pandas
 pydantic
 torch>=2.4
 tqdm
-transformers[sentencepiece]
+transformers[sentencepiece]<5.0
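
The <5.0 pin keeps the examples on a transformers release whose import paths match the code in this commit. As a rough illustration (not part of this commit, and assuming the packaging library is available), the same constraint can be checked at runtime:

# Illustrative sketch only; mirrors the requirements pin transformers[sentencepiece]<5.0.
from packaging import version
import transformers

# Fail early if an incompatible transformers major version is installed.
assert version.parse(transformers.__version__) < version.parse("5.0")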

src/brevitas_examples/llm/llm_quant/data_utils.py

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@
 import numpy as np
 from optimum.utils.normalized_config import NormalizedConfigManager
 import torch
-from torch.utils.data import DataLoader
 from transformers import AutoConfig

 from brevitas_examples.llm.llm_quant.data import get_clm_dataset

src/brevitas_examples/llm/llm_quant/rotation_optimization.py

Lines changed: 6 additions & 1 deletion
@@ -16,7 +16,12 @@
 import transformers
 from transformers import Trainer
 from transformers.data.data_collator import InputDataClass
-from transformers.tokenization_utils import PreTrainedTokenizerBase
+
+try:
+    from transformers.tokenization_utils import PreTrainedTokenizerBase
+except:
+    # This has changed in transformers v5
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from brevitas.graph.calibrate import quantization_status_manager
 from brevitas.optim.cailey_sgd import CaileySGD
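
The try/except above is the usual fallback-import pattern for a symbol that moved between library versions. A minimal standalone sketch of the same idea, assuming (as the in-diff comment states) that transformers v5 relocates PreTrainedTokenizerBase to tokenization_utils_base; catching ImportError explicitly is the slightly stricter variant of what the commit does:

# Sketch of the fallback pattern: ImportError narrows the fallback to the
# missing-symbol case instead of swallowing every exception.
try:
    from transformers.tokenization_utils import PreTrainedTokenizerBase  # transformers < 5
except ImportError:
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase  # transformers >= 5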

src/brevitas_examples/llm/llm_quant/run_utils.py

Lines changed: 0 additions & 19 deletions
@@ -20,35 +20,16 @@
 """

 from contextlib import contextmanager
-import inspect

 from optimum.utils.normalized_config import NormalizedConfigManager
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
 from transformers import AutoConfig
-from transformers.utils.fx import symbolic_trace

 from brevitas.fx.value_tracer import ValueProxy


-def get_fx(model, is_export=True):
-    forward_signature = inspect.signature(model.forward).parameters
-    if all(input_name in forward_signature
-           for input_name in ["input_ids", "attention_mask", "past_key_values"]):
-        input_names = ["input_ids", "attention_mask", "past_key_values"]
-        if not is_export:
-            input_names.remove('past_key_values')
-    else:
-        raise ValueError(
-            f"Quantization with an FX graph is currently only supported for models taking `input_ids`, `attention_mask` and `past_key_values` as inputs. The model only has the following inputs: {forward_signature}"
-        )
-
-    with torch.no_grad():
-        model = symbolic_trace(model, input_names)
-    return model
-
-
 def modify_dataloader(model_name_or_path, data, dtype):
     config = AutoConfig.from_pretrained(model_name_or_path)
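
For context, modify_dataloader (which survives the deletion) starts from AutoConfig.from_pretrained. A minimal usage sketch, not part of this commit, with the model name borrowed from the test cases further down and assuming Hub or cache access:

# Minimal sketch: load a model configuration by name or local path.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
print(config.model_type)  # e.g. "mistral"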

tests/brevitas_examples/test_llm_cases.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ class LLMPerplexityCases:
         "input_scale_type": "dynamic",
         "input_quant_type": "sym",
         "float_ppl": 32428.475,
-        "quant_ppl": 32428.383},
+        "quant_ppl": 32447.685546875},
     {
         "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
         "act_equalization": "layerwise",
