Commit d64aa05

2 parents d4dbc60 + c21a889 commit d64aa05

File tree: GPTQ.py, README.md, generate.py, model.py, quantize.py

5 files changed: +48 -35 lines changed

GPTQ.py

Lines changed: 2 additions & 2 deletions
@@ -150,9 +150,9 @@ def __init__(
         }

         # trace model for one input
-        one_input = [multi.values[0] for multi in inputs]
+        one_input = [multi.values[0].cpu() for multi in inputs]
         exported_model = torch._dynamo.export(
-            model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
+            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
         super().__init__(exported_model.graph_module)
         self.new_state_dict = model.state_dict()
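
Both hunks move GPTQ tracing to the CPU: the sample inputs and the model are sent to .cpu() before torch._dynamo.export runs, so building the fake-tensor trace no longer requires the full model to sit in GPU memory. A minimal sketch of the same pattern, using a toy nn.Linear as a stand-in for the real Transformer (torch._dynamo.export is a private API whose signature has varied across PyTorch nightlies, so this mirrors the call as written above rather than a stable interface):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)                      # stand-in for the model being quantized
one_input = (torch.randn(2, 8).cpu(),)       # example inputs, moved to CPU as above

# Trace on CPU with fake tensors; no GPU memory is touched during export.
exported_model = torch._dynamo.export(
    model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
)(*one_input)
graph_module = exported_model.graph_module   # the FX graph that GPTQ.py wraps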

README.md

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,9 @@ This is *NOT* intended to be a "framework" or "library" - it is intended to show

 For an in-depth walkthrough of what's in this codebase, see this [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/).

+## Examples
+In the spirit of keeping the repo minimal, here are various examples of extensions you can make to gpt-fast as PRs.
+- [Gemma support](https://github.com/pytorch-labs/gpt-fast/pull/115)
 ## Supported Models

 ### LLaMA family
@@ -37,6 +40,7 @@ Projects inspired by gpt-fast in the community:

 - [gpt-blazing](https://github.com/armed-gpt/gpt-blazing): applies the same performance optimization strategy to more models (e.g., baichuan2).
 - [gptfast](https://github.com/MDK8888/GPTFast): applies a subset of the performance optimizations to all Huggingface models
+- [gpt-accelera](https://github.com/Edward-Sun/gpt-accelera): extends `gpt-fast` to SFT/RM/PPO training and batched inference to optimize the throughput

 ## Installation
 [Download PyTorch nightly](https://pytorch.org/get-started/locally/)
@@ -53,6 +57,7 @@ Then login with `huggingface-cli login`
 ## Downloading Weights
 Models tested/supported
 ```text
+tinyllamas/stories{15,42,100}
 openlm-research/open_llama_7b
 meta-llama/Llama-2-7b-chat-hf
 meta-llama/Llama-2-13b-chat-hf

generate.py

Lines changed: 2 additions & 0 deletions
@@ -235,6 +235,8 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         model = simple_quantizer.convert_for_runtime(use_cuda)

     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)

     if use_tp:
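
The two added lines handle the tinyllamas "stories" checkpoints, which wrap their weights under a "model" key rather than storing the flat state dict that the Llama/Mistral checkpoints use. A small sketch of the unwrapping step in isolation, with a hypothetical checkpoint path:

import torch

checkpoint_path = "checkpoints/stories15M/stories15M.pt"   # hypothetical path

checkpoint = torch.load(checkpoint_path, mmap=True, weights_only=True)
if "model" in checkpoint and "stories" in checkpoint_path:
    checkpoint = checkpoint["model"]   # unwrap the nested state dict
# model.load_state_dict(checkpoint, assign=True) then proceeds as before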

model.py

Lines changed: 13 additions & 4 deletions
@@ -63,6 +63,8 @@ def from_name(cls, name: str):
     "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
     "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
     "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
+    "stories15M": dict(n_layer=6, n_head=6, dim=288),
+    "stories110M": dict(n_layer=12, n_head=12, dim=768),
 }

 class KVCache(nn.Module):
@@ -105,10 +107,16 @@ def setup_caches(self, max_batch_size, max_seq_length):
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
+        dtype = self.output.weight.dtype
+        # For quantized layers, dtype is encoded in scales
+        if hasattr(self.output, "scales"):
+            dtype = self.output.scales.dtype
+        elif hasattr(self.output, "scales_and_zeros"):
+            dtype = self.output.scales_and_zeros.dtype
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -220,14 +228,15 @@ def forward(self, x: Tensor) -> Tensor:


 def precompute_freqs_cis(
-    seq_len: int, n_elem: int, base: int = 10000
+    seq_len: int, n_elem: int, base: int = 10000,
+    dtype: torch.dtype = torch.bfloat16
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-    return cache.to(dtype=torch.bfloat16)
+    return cache.to(dtype=dtype)


 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
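
setup_caches now infers the runtime dtype from the output projection, falling back to its scales (int8 weight-only) or scales_and_zeros (int4 weight-only) buffers when the layer is quantized, and threads that dtype through both KVCache and precompute_freqs_cis instead of hard-coding bfloat16. The patched rope helper is reproduced below as a self-contained sketch so it can be tried standalone (the canonical copy lives in model.py):

import torch
from torch import Tensor

def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    # Same body as in the hunk above; the rotary cache is cast to the caller's dtype.
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)

# e.g. a float16 cache for a checkpoint whose int4 scales_and_zeros are float16:
cache = precompute_freqs_cis(2048, 64, dtype=torch.float16)   # shape (2048, 32, 2)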

quantize.py

Lines changed: 26 additions & 29 deletions
@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_til
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros

+def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)

 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)


 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):

             weight = mod.weight.data
             if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                if self.padding:
+                if self.padding_allowed:
                     from model import find_multiple
                     import torch.nn.functional as F
                     print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict

     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod

 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -460,7 +458,10 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = find_multiple(k, 1024)
+            if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
+                new_k = find_multiple(k, 1024)
+            else:
+                new_k = k
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
@@ -485,11 +486,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
@@ -502,16 +503,10 @@ def __init__(

         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        if use_cuda:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-            )
-        else:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
-            )
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+        )
         self.register_buffer(
             "scales_and_zeros",
             torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
@@ -544,7 +539,7 @@ def quantize(
     device: str = default_device,
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
-
+    device = 'cpu'
     precision = torch.bfloat16

     print("Loading model ...")
@@ -554,6 +549,8 @@ def quantize(
     model = Transformer.from_name(checkpoint_path.parent.name)

     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(dtype=precision, device=device)

@@ -597,7 +594,7 @@ def quantize(

         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
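
The net effect of the quantize.py changes is that padding is no longer a constructor flag: the handlers carry a padding_allowed policy, and WeightOnlyInt4Linear decides for itself (via _check_linear_int4_k) whether a layer's in_features must be padded up to the next multiple of 1024 before int4 packing. A self-contained sketch of that decision, with find_multiple re-implemented here only to keep the example standalone (the real helper lives in model.py):

def find_multiple(n: int, k: int) -> int:
    # Round n up to the next multiple of k (mirrors model.find_multiple).
    return n if n % k == 0 else n + k - (n % k)

def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    # A layer is usable by the int4 kernel as-is only if both divisibility checks pass.
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

for in_features in (4096, 11008, 700):
    needs_padding = not _check_linear_int4_k(in_features, groupsize=128, inner_k_tiles=8)
    padded = find_multiple(in_features, 1024) if needs_padding else in_features
    print(f"{in_features} -> {padded}")   # 4096 -> 4096, 11008 -> 11008, 700 -> 1024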
