Skip to content

Commit 56abdc2

Browse files
test cleanup: add deprecated marker, move benchmarks out
1 parent e3051fa commit 56abdc2

File tree

9 files changed

+612
-616
lines changed

9 files changed

+612
-616
lines changed

benchmarking/int8/int8_benchmark.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
tokenizer = AutoTokenizer.from_pretrained(model_name)
1818
input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0)
1919

20-
max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
21-
2220
model = AutoModelForCausalLM.from_pretrained(
2321
model_name,
2422
device_map="auto",
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
Extracted from tests/test_functional.py
3+
4+
Note: This feature is currently unused! It is kept here for archival purposes.
5+
6+
Usage: pytest benchmarking/int8/row_scale_benchmark.py
7+
"""
8+
9+
import time
10+
11+
import pytest
12+
import torch
13+
14+
from bitsandbytes import functional as F
15+
16+
# Number of iterations run by each benchmark loop in this module.
k = 20
17+
# Module-level side effect: configure how torch formats tensors when printed
# (5-digit precision, no scientific notation, 120-character lines).
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
18+
19+
20+
@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    [
        pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"),
        pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"),
    ],
)
@pytest.mark.skip("Row scale has some bugs for ampere")
@pytest.mark.benchmark
def test_row_scale_bench(dim1, dim4, inner):
    """Time an fp16 matmul against int8 matmuls with and without ``row_scale``.

    Three timings, each covering ``k`` iterations, are printed in seconds:

    * ``"16"``          -- ``torch.matmul(A, B.t())`` on half-precision inputs.
    * ``"row-wise"``    -- ``F.int8_linear_matmul`` with ``dtype=torch.int8``
      and a ``row_scale`` argument.
    * ``"vector-wise"`` -- ``F.int8_linear_matmul`` with default arguments.

    The test is currently skipped unconditionally (see the ``skip`` marker).

    Args:
        dim1: Number of rows of the activation matrix ``A``.
        dim4: Number of rows of the weight matrix ``B``.
        inner: Shared inner dimension (columns of both ``A`` and ``B``).
    """
    formatB = F.get_special_format_str()
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)

    # Warm-up: run the fp16 matmul untimed so one-time setup cost is not
    # attributed to the first timed loop.
    for _ in range(k):
        torch.matmul(A, B.t())

    # CUDA kernels launch asynchronously, so synchronize before starting and
    # before stopping the clock; perf_counter is monotonic, unlike time.time.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(k):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.perf_counter() - t0)

    # Quantize and lay out the operands for the int8 matmul. Only the first
    # return value of each helper is needed here, except for the second
    # value of the dim=1 quantization of A, which feeds row_scale below.
    C1a, *_ = F.int8_double_quant(A)
    CB, _ = F.vectorwise_quant(B, quant_type="linear")
    A2, _ = F.nvidia_transform(C1a, "col32")
    B2, _ = F.nvidia_transform(CB, formatB)
    _, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(k):
        F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
    torch.cuda.synchronize()
    print("row-wise", time.perf_counter() - t0)

    # Re-quantize B with int8_double_quant for the run without row_scale.
    C2a, *_ = F.int8_double_quant(B)
    B2, _ = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(k):
        F.int8_linear_matmul(A2, B2)
    torch.cuda.synchronize()
    print("vector-wise", time.perf_counter() - t0)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
Extracted from tests/test_functional.py
3+
4+
Usage: pytest benchmarking/int8/training_benchmark.py
5+
"""
6+
7+
import time
8+
9+
import pytest
10+
import torch
11+
12+
from bitsandbytes import functional as F
13+
14+
# Number of iterations run by each benchmark loop in this module.
k = 20
15+
16+
17+
@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [
        pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"),
        pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"),
        pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"),
    ],
)
@pytest.mark.benchmark
def test_bench_8bit_training(batch, seq, model, hidden):
    """Time the fp16 ``fc1`` matmul of a two-layer MLP on CUDA.

    Only ``torch.matmul(A, w1.t())`` is timed; the elapsed time for ``k``
    iterations is printed in seconds. The remaining forward/backward matmuls
    and the whole int8 code path are commented out in this file.

    Args:
        batch: Batch size of the activation tensor.
        seq: Sequence length of the activation tensor.
        model: Model (input feature) dimension.
        hidden: Hidden (fc1 output feature) dimension.
    """
    # NOTE: formatB, grad, w2, dtype and out1 are not used by the active code;
    # they are kept because the commented-out int8 path refers to them.
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    dtype = torch.int8
    # Collapse (batch, seq, model) into a 2-D (batch * seq, model) matrix.
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()

    # Warm-up: run the matmul untimed so one-time setup cost does not end up
    # in the measurement below.
    for _ in range(k):
        torch.matmul(A, w1.t())

    # CUDA kernels launch asynchronously, so synchronize before starting and
    # before stopping the clock; perf_counter is monotonic, unlike time.time.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(k):
        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())  # fc2

        # d1 = torch.matmul(grad, w2)  # delta1
        # d2 = torch.matmul(d1, w1)  # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad)  # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2)  # grad w1

    torch.cuda.synchronize()
    t16 = time.perf_counter() - t0
    print(t16)
58+
59+
# torch.cuda.empty_cache()
60+
61+
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
62+
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
63+
64+
# CTw1, Sw1 = F.transform2(Cw1, formatB)
65+
# CTw2, Sw2 = F.transform2(Cw2, formatB)
66+
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
67+
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
68+
69+
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
70+
# C32A, SA = F.transform2(CA, 'col32')
71+
## fc1
72+
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
73+
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
74+
75+
## fc2
76+
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
77+
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
78+
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
79+
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
80+
81+
## delta1
82+
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
83+
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
84+
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
85+
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
86+
87+
## delta2
88+
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
89+
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
90+
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
91+
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
92+
93+
## grad1
94+
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
95+
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
96+
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
97+
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
98+
99+
## grad2
100+
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
101+
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
102+
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
103+
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
104+
105+
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
106+
107+
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
108+
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
109+
110+
# CTw1, Sw1 = F.transform2(Cw1, formatB)
111+
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
112+
# CTw2, Sw2 = F.transform2(Cw2, formatB)
113+
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
114+
# torch.cuda.synchronize()
115+
# t0 = time.time()
116+
# for i in range(k):
117+
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
118+
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
119+
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
120+
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
121+
122+
# #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
123+
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
124+
# #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
125+
# #CTw2, Sw2 = F.transform2(Cw2, formatB)
126+
# #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
127+
128+
# C32A, SA = F.transform2(CA, 'col32')
129+
130+
# # fc1
131+
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
132+
# #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
133+
134+
# #print(coo_tensor.nnz)
135+
# #out1sp = F.spmm_coo(coo_tensor, w1.t())
136+
# #print(w1.t().shape)
137+
# #out1 = out1dn + out1sp
138+
139+
# # fc2
140+
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
141+
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
142+
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
143+
# #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
144+
145+
# # delta1
146+
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
147+
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
148+
# d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
149+
# #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
150+
151+
# # delta2
152+
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
153+
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
154+
# d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
155+
# #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
156+
157+
# # grad1
158+
# #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
159+
# #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
160+
# #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
161+
# #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
162+
163+
# ## grad2
164+
# #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
165+
# #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
166+
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
167+
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
168+
169+
# torch.cuda.synchronize()
170+
# t8 = time.time() - t0
171+
# print(t8)

0 commit comments

Comments
 (0)