Skip to content

Commit ef703b8

Browse files
authored
AWQ loader for transformers (eth-sri#254)
1 parent 22cca61 commit ef703b8

File tree

2 files changed

+80
-15
lines changed

2 files changed

+80
-15
lines changed

docs/docs/models/hf.md

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ By default, this exposes an [LMQL/LMTP inference API](https://github.com/eth-sri
2828

2929
**Endpoint and Port** By default, models will be served via port `8080`. To change this, you can specify the port via the `--port` option of the `lmql serve-model` command. On the client side, to connect to a model server running on a different port, you can specify the port when constructing an [`lmql.model`](../lib/generations.md#lmql-llm-objects) object:
3030

31-
```
31+
```python
3232
lmql.model("gpt2", endpoint="localhost:9999")
3333
```
3434

@@ -58,4 +58,30 @@ If you want more control over model loading and configuration, you can pass addi
5858

5959
```python
6060
lmql.model("local:gpt2", cuda=True)
61-
```
61+
```
62+
63+
## Quantization
64+
65+
Quantization reduces the precision of model parameters to shrink model size and boost inference speed with minimal accuracy loss. LMQL supports two quantization formats: AWQ (using AutoAWQ) and GPTQ (using AutoGPTQ).
66+
67+
### AutoAWQ
68+
69+
AWQ minimizes quantization error by protecting crucial weights, promoting model efficiency without sacrificing accuracy. It's ideal for scenarios requiring both compression and acceleration of LLMs.
70+
71+
To use `AWQ`-quantized models, first install [AutoAWQ](https://github.com/casper-hansen/AutoAWQ) following the instructions in the repo, then run:
72+
73+
74+
75+
```bash
76+
lmql serve-model TheBloke/Mistral-7B-OpenOrca-AWQ --loader awq
77+
```
78+
79+
### AutoGPTQ
80+
81+
AutoGPTQ reduces model size while retaining performance by lowering the precision of model weights to 4 or 3 bits. It's suitable for efficient deployment and operation of LLMs on consumer-grade hardware.
82+
83+
Install [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) following the repo instructions. To use GPTQ-quantized models, run:
84+
85+
```bash
86+
lmql serve-model TheBloke/Arithmo-Mistral-7B-GPTQ --loader gptq
87+
```

src/lmql/models/lmtp/backends/transformers_model.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,37 +25,69 @@ def merge(kwargs1, kwargs2, prioritize="left"):
2525
class TransformersLLM(LMTPModel):
2626
def __init__(self, model_identifier, **kwargs):
2727
self.model_identifier = model_identifier
28+
29+
self.loader = kwargs.pop("loader", None)
30+
if self.loader is None:
31+
if '-gptq' in self.model_identifier.lower():
32+
self.loader = "gptq"
33+
elif '-awq' in self.model_identifier.lower():
34+
self.loader = "awq"
35+
else:
36+
self.loader = "transformers"
37+
2838
self.model_args = kwargs
29-
self.loader = kwargs.pop("loader", "transformers")
30-
31-
self.max_batch_size = kwargs.pop("batch_size", 32)
39+
self.max_batch_size = kwargs.get("batch_size", 32)
3240

3341
self.silent = kwargs.pop("silent", False)
3442

3543
if not self.silent:
3644
print("[Loading", self.model_identifier, "with", self.model_constructor() + "]", flush=True)
3745

38-
if self.loader == "auto-gptq":
46+
if self.loader == "gptq" or self.loader == "auto-gptq":
3947
from auto_gptq import AutoGPTQForCausalLM
4048
self.model = AutoGPTQForCausalLM.from_quantized(self.model_identifier, **self.model_args)
49+
elif self.loader == 'awq':
50+
from awq import AutoAWQForCausalLM
51+
awq_args = {
52+
'quant_filename': kwargs.pop("quant_filename", ''),
53+
"max_new_tokens": kwargs.pop("max_new_tokens", None),
54+
"trust_remote_code": kwargs.pop("trust_remote_code", True),
55+
"safetensors": kwargs.pop("safetensors", True),
56+
"fuse_layers": False, # TODO: Figure out why this is broken
57+
"max_memory": kwargs.pop("max_memory", None),
58+
"offload_folder": kwargs.pop("offload_folder", None),
59+
"batch_size": kwargs.get("batch_size", 16)
60+
}
61+
self.model = AutoAWQForCausalLM.from_quantized(self.model_identifier, **awq_args)
4162
else:
4263
from transformers import AutoModelForCausalLM
4364
self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, **self.model_args)
4465

66+
if self.loader == 'awq':
67+
self.device = self.model.model.device
68+
else:
69+
self.device = self.model.device
70+
4571
if not self.silent:
46-
print("[", self.model_identifier, " ready on device ", self.model.device,
72+
print("[", self.model_identifier, " ready on device ", self.device,
4773
flush=True, sep="", end="]\n")
4874

4975
@property
5076
def eos_token_id(self):
51-
return self.model.config.eos_token_id
77+
if self.loader == 'awq':
78+
return self.model.model.config.eos_token_id
79+
else:
80+
return self.model.config.eos_token_id
5281

5382
def score(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **model_kwargs) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
54-
input_ids = torch.tensor(input_ids).to(self.model.device)
55-
attention_mask = torch.tensor(attention_mask).to(self.model.device)
83+
input_ids = torch.tensor(input_ids).to(self.device)
84+
attention_mask = torch.tensor(attention_mask).to(self.device)
5685

5786
# prepare model inputs
58-
model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs, attention_mask=attention_mask, eos_token_id=self.eos_token_id)
87+
if self.loader == 'awq':
88+
model_inputs = self.model.model.prepare_inputs_for_generation(input_ids, **model_kwargs, attention_mask=attention_mask, eos_token_id=self.eos_token_id)
89+
else:
90+
model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs, attention_mask=attention_mask, eos_token_id=self.eos_token_id)
5991
model_inputs["attention_mask"] = attention_mask
6092

6193
token_scores = []
@@ -76,8 +108,8 @@ def score(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, *
76108
def generate(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor,
77109
temperature: float, max_new_tokens: int,
78110
bias_tensor: torch.FloatTensor, streamer: TokenStreamer, **kwargs) -> LMTPModelResult:
79-
input_ids = torch.tensor(input_ids).to(self.model.device)
80-
attention_mask = torch.tensor(attention_mask).to(self.model.device)
111+
input_ids = torch.tensor(input_ids).to(self.device)
112+
attention_mask = torch.tensor(attention_mask).to(self.device)
81113

82114
generate_args = {
83115
"input_ids": input_ids,
@@ -117,20 +149,27 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
117149
return [BatchLogitsProcessor()]
118150

119151
def model_constructor(self):
120-
if self.loader == "auto-gptq":
152+
if self.loader == "gptq" or self.loader == "auto-gptq":
121153
return "AutoGPTQForCausalLM.from_quantized({})".format(format_call(self.model_identifier, **self.model_args))
154+
elif self.loader == 'awq':
155+
return "AutoAWQForCausalLM.from_quantized({})".format(format_call(self.model_identifier, **self.model_args))
122156
else:
123157
return "AutoModelForCausalLM.from_pretrained({})]".format(format_call(self.model_identifier, **self.model_args))
124158

125159
def version_info(self):
126160
global version_info
127161

128162
if len(version_info) == 0:
129-
if self.loader == "auto-gptq":
163+
if self.loader == "gptq" or self.loader == "auto-gptq":
130164
import auto_gptq
131165
version_info = {
132166
"auto_gptq": auto_gptq.__version__
133167
}
168+
elif self.loader == "awq":
169+
import awq
170+
version_info = {
171+
"awq": awq.__version__
172+
}
134173
else:
135174
import transformers
136175
version_info = {

0 commit comments

Comments
 (0)