Commit c2bbe64

breaking(gptq): quantizing only 1 layer yields high perplexity

1 parent: cf14124

File tree

7 files changed: +892 -302 lines


quantize/compress_rwkv.py

Lines changed: 0 additions & 239 deletions
This file was deleted.

quantize/gptq/datautils.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import pathlib
 import tokenizers
 import random
-from rwkv.model import RWKV
+from myRWKV import RWKV
 
 from datasets import load_dataset

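For context, datautils.py supplies the calibration samples GPTQ needs, which is why it imports datasets.load_dataset. A minimal sketch of what such a loader can look like, with the dataset name, text field, and helper name all hypothetical rather than taken from this repository:

import random
from datasets import load_dataset

# Hypothetical calibration loader: GPTQ estimates quantization
# parameters from a handful of raw text samples.
def get_calibration_texts(n_samples: int = 16, seed: int = 0) -> list[str]:
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    random.seed(seed)
    # Oversample row indices, then drop empty rows.
    candidates = random.sample(range(len(data)), n_samples * 4)
    texts = [data[i]['text'] for i in candidates if data[i]['text'].strip()]
    return texts[:n_samples]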
quantize/gptq/quant.py

Lines changed: 3 additions & 2 deletions
@@ -202,7 +202,8 @@ def pack(self, weight, bias, scales, zeros, g_idx = None):
 
         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
+            #OLD: intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
+            intweight.append(torch.round((weight.data[idx, :] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
         intweight = torch.cat(intweight,dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
@@ -411,7 +412,7 @@ def pack(self, linear, scales, zeros, g_idx = None):
         qweight = qweight.astype(np.int32)
         self.qweight = torch.from_numpy(qweight)
 
-        zeros -= 1;
+        zeros -= 1
         zeros = zeros.numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
         i = 0
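The first hunk is the substance of this commit: pack() now quantizes weight row idx (one input feature across all output features) instead of column idx, which only lines up if the weight tensor arrives as [in_features, out_features], i.e. transposed relative to the usual nn.Linear layout. A minimal round-trip sketch of the per-group affine quantization under that assumed layout, with made-up shapes, scales, and zero points standing in for the repository's actual values:

import torch

# Minimal sketch, assuming weight is stored [in_features, out_features]
# (the layout the fixed indexing implies); all sizes are made up.
in_features, out_features, group_size, bits = 8, 4, 4, 4
n_groups = in_features // group_size
weight = torch.randn(in_features, out_features)
g_idx = torch.arange(in_features) // group_size          # feature -> group map
scales = (weight.abs().view(n_groups, group_size, out_features)
                .amax(dim=(1, 2)).view(n_groups, 1) * 2 / (2**bits - 1))
zeros = torch.full((n_groups, 1), 2.0 ** (bits - 1))     # mid-range zero point
scale_zeros = scales * zeros

rows = []
for idx in range(in_features):
    # The corrected line: quantize row idx, not column idx.
    rows.append(torch.round((weight[idx, :] + scale_zeros[g_idx[idx]])
                            / scales[g_idx[idx]]).to(torch.int))
intweight = torch.stack(rows, dim=0)

# Dequantize and check the round trip stays within half a quantization step.
dequant = intweight.float() * scales[g_idx] - scale_zeros[g_idx]
assert (weight - dequant).abs().max() <= scales.max() / 2 + 1e-6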

quantize/measure_perplexity.py

Lines changed: 9 additions & 6 deletions
@@ -10,6 +10,8 @@
 import torch
 from typing import List
 from rwkv.model import RWKV
+os.environ['RWKV_JIT_ON'] = '1'
+os.environ["RWKV_CUDA_ON"] = '0'
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
@@ -56,9 +58,10 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 
 # ---
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-# device=torch.device('cpu')
+# device = torch.device('cpu')
 
-model = RWKV(model=args.model_path, strategy='cuda fp16i8')
+#TODO: Why is PERPLEXITY SO DAMN HIGH ?
+model = RWKV(model=args.model_path, strategy='cuda fp16')
 
 logits, state = None, None
 loss_sum: torch.Tensor = torch.tensor([0.0], device=device)
@@ -72,7 +75,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 for i in range(run_count):
     token: int = test_tokens[i]
     target: int = test_tokens[i + 1]
-
+
     logits, state = model.forward([token], None if i == 0 else state)
 
     if ignore_first_n_tokens == 0 or i + 1 >= ignore_first_n_tokens:
@@ -105,7 +108,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
 
 print()
-print(f'Model: {os.path.basename(args.model_path)}, '
-      f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens, '
-      f'Ignored first {ignore_first_n_tokens} tokens, '
+print(f'Model: {os.path.basename(args.model_path)}\n'
+      f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens\n'
+      f'Ignored first {ignore_first_n_tokens} tokens\n'
       f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
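One note on the metric itself: the perplexity this script reports (the number the TODO above complains about) is just the exponential of the mean per-token cross-entropy accumulated in loss_sum. A self-contained sketch of that arithmetic, with random logits standing in for real model output:

import torch
import torch.nn.functional as F

# Hypothetical stand-ins: 5 target tokens from a 10-symbol vocabulary
# and the logits a model produced for each next-token prediction.
vocab_size = 10
targets = torch.tensor([3, 1, 4, 1, 5])
logits = torch.randn(len(targets), vocab_size)

# Per-token cross-entropy, summed over the run, then averaged --
# the same accumulation the script performs.
loss_sum = torch.tensor(0.0)
for step_logits, target in zip(logits, targets):
    loss_sum += F.cross_entropy(step_logits.unsqueeze(0), target.unsqueeze(0))
mean_loss = loss_sum / len(targets)

# Perplexity is exp(mean cross-entropy); lower is better.
print(f'loss {mean_loss.item():.3f}, perplexity {mean_loss.exp().item():.3f}')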
