Commit f64cfe6

Fixed prefetch bug for non-paged tensors; added benchmark.
1 parent 41a9c70 commit f64cfe6

2 files changed: +50 −3 lines changed

bitsandbytes/optim/optimizer.py

Lines changed: 6 additions & 3 deletions
@@ -314,9 +314,12 @@ def get_state_buffer(self, p, dtype=torch.float32):
     def prefetch_state(self, p):
         if self.is_paged:
             state = self.state[p]
-            F.prefetch_tensor(state['state1'])
-            if 'state2' in state:
-                F.prefetch_tensor(state['state2'])
+            s1 = state['state1']
+            is_paged = getattr(s1, 'is_paged', False)
+            if is_paged:
+                F.prefetch_tensor(state['state1'])
+                if 'state2' in state:
+                    F.prefetch_tensor(state['state2'])


 class Optimizer2State(Optimizer8bit):
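
For context, a minimal sketch (not part of the commit) of the guard this change adds. Even when the optimizer itself is paged, get_state_buffer can hand back a regular (non-paged) buffer for small parameters, and prefetching such a tensor fails; paged buffers are allocated via F.get_paged, which tags them with an is_paged attribute, so getattr(s1, 'is_paged', False) lets prefetch_state skip everything else. The maybe_prefetch helper below is hypothetical and only mirrors that logic.

import torch
import bitsandbytes.functional as F

def maybe_prefetch(t):
    # Only tensors backed by bitsandbytes' paged (unified) memory are prefetched;
    # regular CUDA tensors lack the is_paged attribute and are skipped.
    if getattr(t, 'is_paged', False):
        F.prefetch_tensor(t)

paged_buf = F.get_paged(1024, dtype=torch.float32)   # tagged with is_paged == True
plain_buf = torch.zeros(1024, device='cuda')         # no is_paged attribute
maybe_prefetch(paged_buf)   # async prefetch is issued
maybe_prefetch(plain_buf)   # skipped, which is the behavior this commit fixes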

tests/test_optim.py

Lines changed: 44 additions & 0 deletions
@@ -490,3 +490,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     params = (k - k // 5) * dim1 * dim2
     print(optim_name, gtype, s / params)
     # assert s < 3.9
+
+dim1 = [10*1024]
+gtype = [torch.float16]
+#mode = ['torch', 'bnb']
+mode = ['bnb']
+optimizer_names = ['paged_adamw']
+#optimizer_names = ['paged_adamw8bit_blockwise']
+values = list(product(dim1,gtype, optimizer_names, mode))
+names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
+def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
+    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
+    layers1 = layers1.to(gtype)
+    layers1 = layers1.cuda()
+
+    large_tensor = None
+    if mode == 'torch':
+        optim = str2optimizers[optim_name][0](layers1.parameters())
+    else:
+        optim = str2optimizers[optim_name][1](layers1.parameters())
+        # 12 GB
+        large_tensor = torch.empty((int(4.5e9),), device='cuda')
+
+    torch.cuda.synchronize()
+    time.sleep(5)
+
+    num_batches = 5
+    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
+    lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()
+
+    for i in range(num_batches):
+        print(i)
+        b = batches[i]
+        if i ==2:
+            torch.cuda.synchronize()
+            t0 = time.time()
+
+        out1 = layers1(b)
+
+        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
+        loss1.backward()
+        optim.step()
+    torch.cuda.synchronize()
+    print(mode, time.time() - t0)
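
The benchmark exercises the paged optimizers end to end; the large_tensor allocation in the 'bnb' branch appears intended to create memory pressure so the paged state is actually evicted and prefetched. For reference, a self-contained usage sketch outside the test harness (assuming the bnb.optim.PagedAdamW API that 'paged_adamw' maps to; the names below are illustrative, not from the commit):

import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(4)]).half().cuda()
optim = bnb.optim.PagedAdamW(model.parameters(), lr=1e-4)   # paged optimizer state

x = torch.randn(128, 1024, device='cuda', dtype=torch.float16)
tgt = torch.randint(0, 10, (128,), device='cuda')

out = model(x)
loss = torch.nn.functional.cross_entropy(out.float(), tgt)
loss.backward()
optim.step()        # prefetch_state() runs internally before each paged update
optim.zero_grad()

The test itself can be run in isolation with pytest's -k filter on test_stream_optimizer_bench (with -s to see the timing output).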
